From 4772f9afc2c6880e7e8d5190858d314a7a68e232 Mon Sep 17 00:00:00 2001 From: Jialei A Wang Date: Fri, 16 Aug 2024 01:31:49 +0000 Subject: [PATCH 01/38] introduce benchgc for correctness check --- .github/workflows/build.yml | 12 + .github/workflows/style.yml | 21 + .gitignore | 1 + CMakeLists.txt | 1 + scripts/correctness.sh | 106 +++ test/CMakeLists.txt | 4 + test/benchgc/.gitignore | 5 + test/benchgc/CMakeLists.txt | 41 + test/benchgc/README.md | 257 ++++++ test/benchgc/cases/generic.mlir | 15 + test/benchgc/cases/llama2.mlir | 113 +++ test/benchgc/cases/reduce.mlir | 12 + test/benchgc/setup.py | 30 + test/benchgc/src/benchgc/CMakeLists.txt | 22 + test/benchgc/src/benchgc/__init__.py | 20 + test/benchgc/src/benchgc/__main__.py | 274 ++++++ test/benchgc/src/benchgc/arg/CMakeLists.txt | 22 + test/benchgc/src/benchgc/arg/__init__.py | 163 ++++ test/benchgc/src/benchgc/arg/arg.py | 54 ++ test/benchgc/src/benchgc/arg/binary.py | 98 ++ test/benchgc/src/benchgc/arg/compare.py | 137 +++ test/benchgc/src/benchgc/arg/conv.py | 189 ++++ test/benchgc/src/benchgc/arg/eltwise.py | 177 ++++ test/benchgc/src/benchgc/arg/matmul.py | 173 ++++ test/benchgc/src/benchgc/arg/pool.py | 97 ++ test/benchgc/src/benchgc/arg/reduce.py | 78 ++ test/benchgc/src/benchgc/arg/softmax.py | 93 ++ test/benchgc/src/benchgc/arith/CMakeLists.txt | 22 + test/benchgc/src/benchgc/arith/__init__.py | 45 + test/benchgc/src/benchgc/arith/basic.py | 58 ++ .../benchgc/src/benchgc/linalg/CMakeLists.txt | 22 + test/benchgc/src/benchgc/linalg/__init__.py | 52 ++ test/benchgc/src/benchgc/linalg/binary.py | 137 +++ test/benchgc/src/benchgc/linalg/conv.py | 834 ++++++++++++++++++ test/benchgc/src/benchgc/linalg/eltwise.py | 197 +++++ test/benchgc/src/benchgc/linalg/generic.py | 236 +++++ test/benchgc/src/benchgc/linalg/matmul.py | 317 +++++++ test/benchgc/src/benchgc/linalg/misc.py | 97 ++ test/benchgc/src/benchgc/linalg/pool.py | 489 ++++++++++ test/benchgc/src/benchgc/linalg/softmax.py | 47 + test/benchgc/src/benchgc/mlir/CMakeLists.txt | 22 + test/benchgc/src/benchgc/mlir/__init__.py | 15 + test/benchgc/src/benchgc/mlir/arg.py | 174 ++++ test/benchgc/src/benchgc/mlir/module.py | 48 + test/benchgc/src/benchgc/mlir/util.py | 111 +++ test/benchgc/src/benchgc/runner.py | 109 +++ .../benchgc/src/benchgc/tensor/CMakeLists.txt | 22 + test/benchgc/src/benchgc/tensor/__init__.py | 45 + test/benchgc/src/benchgc/tensor/basic.py | 33 + test/benchgc/src/benchgc/tensor/shape.py | 59 ++ test/benchgc/src/benchgc/util.py | 341 +++++++ 51 files changed, 5747 insertions(+) create mode 100755 scripts/correctness.sh create mode 100644 test/benchgc/.gitignore create mode 100644 test/benchgc/CMakeLists.txt create mode 100644 test/benchgc/README.md create mode 100644 test/benchgc/cases/generic.mlir create mode 100644 test/benchgc/cases/llama2.mlir create mode 100644 test/benchgc/cases/reduce.mlir create mode 100644 test/benchgc/setup.py create mode 100644 test/benchgc/src/benchgc/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/__init__.py create mode 100644 test/benchgc/src/benchgc/__main__.py create mode 100644 test/benchgc/src/benchgc/arg/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/arg/__init__.py create mode 100644 test/benchgc/src/benchgc/arg/arg.py create mode 100644 test/benchgc/src/benchgc/arg/binary.py create mode 100644 test/benchgc/src/benchgc/arg/compare.py create mode 100644 test/benchgc/src/benchgc/arg/conv.py create mode 100644 test/benchgc/src/benchgc/arg/eltwise.py create mode 100644 test/benchgc/src/benchgc/arg/matmul.py create mode 100644 test/benchgc/src/benchgc/arg/pool.py create mode 100644 test/benchgc/src/benchgc/arg/reduce.py create mode 100644 test/benchgc/src/benchgc/arg/softmax.py create mode 100644 test/benchgc/src/benchgc/arith/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/arith/__init__.py create mode 100644 test/benchgc/src/benchgc/arith/basic.py create mode 100644 test/benchgc/src/benchgc/linalg/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/linalg/__init__.py create mode 100644 test/benchgc/src/benchgc/linalg/binary.py create mode 100644 test/benchgc/src/benchgc/linalg/conv.py create mode 100644 test/benchgc/src/benchgc/linalg/eltwise.py create mode 100644 test/benchgc/src/benchgc/linalg/generic.py create mode 100644 test/benchgc/src/benchgc/linalg/matmul.py create mode 100644 test/benchgc/src/benchgc/linalg/misc.py create mode 100644 test/benchgc/src/benchgc/linalg/pool.py create mode 100644 test/benchgc/src/benchgc/linalg/softmax.py create mode 100644 test/benchgc/src/benchgc/mlir/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/mlir/__init__.py create mode 100644 test/benchgc/src/benchgc/mlir/arg.py create mode 100644 test/benchgc/src/benchgc/mlir/module.py create mode 100644 test/benchgc/src/benchgc/mlir/util.py create mode 100644 test/benchgc/src/benchgc/runner.py create mode 100644 test/benchgc/src/benchgc/tensor/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/tensor/__init__.py create mode 100644 test/benchgc/src/benchgc/tensor/basic.py create mode 100644 test/benchgc/src/benchgc/tensor/shape.py create mode 100644 test/benchgc/src/benchgc/util.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 102d3906a..0ca233a76 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,3 +45,15 @@ jobs: - name: Test run: | cmake --build build --target gc-check + + - name: Build and install benchgc + working-directory: build + run: | + ninja benchgc + pip uninstall -y benchgc || true + pip install test/benchgc/dist/benchgc-*.whl + - name: Correctness Test + env: + LD_PRELOAD: /lib/x86_64-linux-gnu/libomp5.so + run: | + scripts/correctness.sh \ No newline at end of file diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 4c66e1f14..91efdebf6 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -22,3 +22,24 @@ jobs: run: | clang-format --version find . -name *.cpp -or -name *.hpp | xargs clang-format --dry-run --Werror -style=file + + python_format: + runs-on: ubuntu-latest + steps: + - name: checkout base version + uses: actions/checkout@v4 + with: + fetch-depth: 100 + ref: ${{ github.event.pull_request.base.sha }} + + - name: checkout head version + uses: actions/checkout@v4 + with: + fetch-depth: 100 + ref: ${{ github.event.pull_request.head.sha }} + + - name: install darker + run: "python3 -m pip install darker darker[isort] darker[flynt]" + + - name: check python format + run: "python3 -m darker --check -i -f --diff -r `git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}`...HEAD ." \ No newline at end of file diff --git a/.gitignore b/.gitignore index e1fe789da..40c724cd9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ build/ externals/ compile_commands.json +__pycache__ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 90a89666f..d9dcd7313 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ option(GC_ENABLE_IMEX "Enable Intel® Extension for MLIR" OFF) option(GC_ENABLE_BINDINGS_PYTHON "Enable Graph Complier Python Binding" ON) option(GC_DEV_LINK_LLVM_DYLIB "Link dynamic libraries of LLVM and MLIR. For developers only. Do not use it in packing the library." OFF) option(GC_ENABLE_RUNTIME_NAIVE_BRGEMM "Use naive BRGEMM as runtime backend for debug purpose." OFF) +option(GC_BENCH_ENABLE "Build benchgc." ON) if(GC_ENABLE_LEGACY) add_subdirectory(legacy/core) diff --git a/scripts/correctness.sh b/scripts/correctness.sh new file mode 100755 index 000000000..c0ae008ce --- /dev/null +++ b/scripts/correctness.sh @@ -0,0 +1,106 @@ +#! /bin/bash + +export CASE_DIR=$(pwd)/test/benchgc/cases + +FAIL=0 +set -e + +# bf16 +python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:32x128xbf16 --md 1:128x64xbf16 --md 2:32x64xbf16 --cast cast_signed || FAIL=1 + +# f32 + +# misc +python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case broadcast --md 0:1024xf32 --md 1:2x32x1024xf32 --dimensions=0 --dimensions=1 || FAIL=1 + +# matmul +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:16x512x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul_transpose_a --md 0:16x512x64xf32 --md 1:16x512x32xf32 --md 2:16x64x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul_transpose_b --md 0:16x512x64xf32 --md 1:16x128x64xf32 --md 2:16x512x128xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matvec --md 0:16x512x64xf32 --md 1:16x64xf32 --md 2:16x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_mmt4d --md 0:4x4x8x4x2xf32 --md 1:4x8x8x4x2xf32 --md 2:4x4x8x4x4xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_reduce_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:512x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_vecmat --md 0:16x64xf32 --md 1:16x64x512xf32 --md 2:16x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case dot --md 0:4096xf32 --md 1:4096xf32 --md 2:0xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:1024x512xf32 --md 1:512x512xf32 --md 2:1024x512xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul_transpose_a --md 0:1024x512xf32 --md 1:1024x512xf32 --md 2:512x512xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul_transpose_b --md 0:1024x512xf32 --md 1:1024x512xf32 --md 2:1024x1024xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matvec --md 0:512x64xf32 --md 1:64xf32 --md 2:512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case mmt4d --md 0:4x8x4x2xf32 --md 1:8x8x4x2xf32 --md 2:4x8x4x4xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case vecmat --md 0:512xf32 --md 1:512x64xf32 --md 2:64xf32 || FAIL=1 + +# binary +python3 -m benchgc --verbose 0 --driver linalg --case add --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case sub --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case mul --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case div --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case max --md 0:1024x1024xf32 --md 1:1024x1024xf32 --md 2:1024x1024xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case min --md 0:1024x1024xf32 --md 1:1024x1024xf32 --md 2:1024x1024xf32 || FAIL=1 + +# element wise +python3 -m benchgc --verbose 0 --driver linalg --case abs --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case ceil --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case erf --md 0:1024x512xf32 --md 1:1024x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case floor --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case log --md 0:4096x32xf32 --md 1:4096x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case negf --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case exp --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case round --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +# python3 -m benchgc --verbose 0 --driver linalg --case rsqrt --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case sqrt --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case square --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case tanh --md 0:128x128xf32 --md 1:128x128xf32 || FAIL=1 + +# conv +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d_ncw_fcw --md 0:4x4x32xf32 --md 1:8x4x4xf32 --md 2:4x8x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d_nwc_wcf --md 0:4x32x4xf32 --md 1:4x4x8xf32 --md 2:4x13x8xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d --md 0:32xf32 --md 1:4xf32 --md 2:29xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nchw_fchw --md 0:4x4x32x32xf32 --md 1:8x4x4x4xf32 --md 2:4x8x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_ngchw_fgchw --md 0:4x2x2x32x32xf32 --md 1:4x2x2x4x4xf32 --md 2:4x2x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_ngchw_gfchw --md 0:4x2x2x32x32xf32 --md 1:2x4x2x4x4xf32 --md 2:4x2x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nhwc_fhwc --md 0:4x32x32x4xf32 --md 1:8x4x4x4xf32 --md 2:4x13x13x8xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nhwc_hwcf --md 0:4x32x32x4xf32 --md 1:4x4x4x8xf32 --md 2:4x13x13x8xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d --md 0:32x32xf32 --md 1:4x4xf32 --md 2:29x29xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d_ncdhw_fcdhw --md 0:4x4x32x32x32xf32 --md 1:8x4x4x4x4xf32 --md 2:4x8x13x13x13xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d_ndhwc_dhwcf --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4x8xf32 --md 2:4x13x13x13x8xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d --md 0:32x32x32xf32 --md 1:4x4x4xf32 --md 2:29x29x29xf32 || FAIL=1 + +# depthwise conv +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_ncw_cw --md 0:4x4x32xf32 --md 1:4x4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_nwc_wc --md 0:4x32x4xf32 --md 1:4x4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_nwc_wcm --md 0:4x32x4xf32 --md 1:4x4x3xf32 --md 2:4x13x4x3xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nchw_chw --md 0:4x4x32x32xf32 --md 1:4x4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nhwc_hwc --md 0:4x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nhwc_hwcm --md 0:4x32x32x4xf32 --md 1:4x4x4x3xf32 --md 2:4x13x13x4x3xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ncdhw_cdhw --md 0:4x4x32x32x32xf32 --md 1:4x4x4x4xf32 --md 2:4x4x13x13x13xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ndhwc_dhwc --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ndhwc_dhwcm --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4x3xf32 --md 2:4x13x13x13x4x3xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 + +# pool +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nchw_max --md 0:4x4x32x32xf32 --md 1:4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nchw_sum --md 0:4x4x32x32xf32 --md 1:4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ncw_max --md 0:4x4x32xf32 --md 1:4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ncw_sum --md 0:4x4x32xf32 --md 1:4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ndhwc_max --md 0:4x32x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ndhwc_sum --md 0:4x32x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_max --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_sum --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_min --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_max --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_sum --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_min --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 + +# generic / reduce +python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/generic.mlir || FAIL=1 +python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || FAIL=1 + +# softmax +# python3 -m benchgc --verbose 0 --driver linalg --case softmax --md 0:32x4096xf32 --md 1:32x4096xf32 --dimension 1 || FAIL=1 + +# mlir +# python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/llama2.mlir || FAIL=1 + +set +e +exit $FAIL \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4e4036b19..4baaa28de 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,3 +6,7 @@ endif () include(gtest) add_subdirectory(dnnl) add_subdirectory(mlir) + +if(GC_BENCH_ENABLE) + add_subdirectory(benchgc) +endif() \ No newline at end of file diff --git a/test/benchgc/.gitignore b/test/benchgc/.gitignore new file mode 100644 index 000000000..0fcd2be1e --- /dev/null +++ b/test/benchgc/.gitignore @@ -0,0 +1,5 @@ +dist/ +src/benchgc.egg-info/ +build +benchgc.egg-info/ +__pycache__ diff --git a/test/benchgc/CMakeLists.txt b/test/benchgc/CMakeLists.txt new file mode 100644 index 000000000..e50f35cf2 --- /dev/null +++ b/test/benchgc/CMakeLists.txt @@ -0,0 +1,41 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +if(NOT GC_BENCH_ENABLE) + message(STATUS "Benchgc is not enabled") + return() +endif() + +configure_file(setup.py ${CMAKE_BINARY_DIR}/test/benchgc/setup.py COPYONLY) + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR} COPYONLY) +endforeach() + +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter NumPy REQUIRED) +add_custom_target(benchgc + COMMAND ${Python_EXECUTABLE} setup.py bdist_wheel + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/test/benchgc/" + DEPENDS GcPythonModules) + +add_subdirectory("src/benchgc") +add_subdirectory("src/benchgc/arg") +add_subdirectory("src/benchgc/mlir") +add_subdirectory("src/benchgc/linalg") +add_subdirectory("src/benchgc/tensor") +add_subdirectory("src/benchgc/arith") diff --git a/test/benchgc/README.md b/test/benchgc/README.md new file mode 100644 index 000000000..77499c5fd --- /dev/null +++ b/test/benchgc/README.md @@ -0,0 +1,257 @@ +# benchgc - benchmark tool for graph compiler + +## Description + +Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files based on the OneDNN graph dialect as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. + +## Prerequisite +* python >= 3.10 +* torch >= 2.2 +* pybind11 + +## Build and install +``` +# Please execute at the top level of the project + +mkdir -p build +cd build + +cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON +make -j benchgc + +python -m pip install test/benchgc/dist/benchgc-*.whl + +``` + +## Synopsis +``` +python -m benchgc [OPTIONS] --driver [DRIVER] --case [CASE] +``` +## Flags +### --driver [str] +* linalg: test the single op in linalg dialect +* mlir: upload a mlir file and run +* pattern: predefined pattern test such as mlp + +### --case [str] +* if driver=mlir, please provide a mlir file here to test +* if driver=pattern, please provide the pre-defined pattern name, such as mlp here +* if driver is a dialect name, please provide the detail op name to start a single op test + +### --seed [int] +* set the seed to generate the test data and reprodce the test + +### --verbose [int] +* set the verbose level + +### --md index:SHAPExTYPE +* Describe the shape and data type for argument +* Not available when driver=mlir +* index means the order of argument, including both inputs and outs +* use prefix `0x` (e.g. `0xbf16`) to represent 0d memref or tensor input +* use data type directly (e.g.`f32`) to represent a normal scalar + +``` +# %arg0 -> index = 0 +# tensor<2x2x2xf32> -> index = 1 + +module { + func.func @entry(%arg0: f32) -> tensor<2x2x2xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<2x2x2xf32> + %1 = linalg.fill ins(%arg0 : f32) outs(%0 : tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + return %1 : tensor<2x2x2xf32> + } +} +``` + +### --fill index:fill_type:[:fill_parameter]* +* If not set, benchgc will assign a default method for the argument + +| description | fill_type | fill_parameter | +|-------------|-----------|-----------| +| Zero | Z | | +| Normal | N | mean, std | +| Poisson | P | lambda | +| Binomial | B | n, p | +| Uniform | U | a, b | +| Integer | I | a, b | +| Pytorch tensor dump | F | dump filename | +| Benchdnn driver | D | driver_name[:driver filling parameter]* | + +#### Benchdnn driver filling + +| driver_name | driver filling parameter | +|-------------|--------------------------| +| binary | src0/src1:src0 dtype:src1 dtype:dst dtype | +| conv | src/wei:src dtype:wei dtype:dst dtype:amplifier | +| eltwise | algorithm: alpha: beta (please check https://oneapi-src.github.io/oneDNN/dev_guide_eltwise.html) | +| matmul | src/wei:src dtype:wei dtype:dst dtype:amplifier | +| pool | not required | + +### --cmp index:cmp_type:[:cmp_parameter]* +* If not set, benchgc will assign a default method for the argument + +| description | cmp_type | cmp_parameter | +|-------------|-----------|-----------| +| P2P check | P | threshold, zero_percent(mistrust check) | +| Norm check | N | threshold | +| Benchdnn driver | D | driver_name:dtype:case | + +## Example +``` +# single add op test +# using the same data filling / compare strategy as the benchdnn primitive driver if not set +python3 -m benchgc --verbose 6 --driver linalg --case add --md 0:4x5xf32 --md 1:4x5xf32 --md 2:4x5xf32 + +arg0 shape: [4, 5] dtype: f32 fill_type: D fill_param: ['binary', 'src0', 'f32', 'f32', 'f32'] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +arg1 shape: [4, 5] dtype: f32 fill_type: D fill_param: ['binary', 'src1', 'f32', 'f32', 'f32'] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +arg2 shape: [4, 5] dtype: f32 fill_type: Z fill_param: [] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +module { + func.func @entry(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<4x5xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x5xf32>) -> tensor<4x5xf32> + %2 = linalg.add ins(%arg0, %arg1 : tensor<4x5xf32>, tensor<4x5xf32>) outs(%1 : tensor<4x5xf32>) -> tensor<4x5xf32> + return %2 : tensor<4x5xf32> + } +} + +fill arg0: +tensor([[ -5.0000, 10.0000, 3.7500, -2.5000, -8.7500], + [ 6.2500, 0.0000, -6.2500, 8.7500, 2.5000], + [ -3.7500, -10.0000, 5.0000, -1.2500, -7.5000], + [ 7.5000, 1.2500, -5.0000, 10.0000, 3.7500]]) +fill arg1: +tensor([[ 1.2500, -5.0000, 10.0000, 3.7500, -2.5000], + [ -8.7500, 6.2500, 1.0000, -6.2500, 8.7500], + [ 2.5000, -3.7500, -10.0000, 5.0000, -1.2500], + [ -7.5000, 7.5000, 1.2500, -5.0000, 10.0000]]) +fill arg2: +tensor([[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) +p2p check: threshold: 0.0000001 + (0, 0): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: -6.2500000 res: -6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -10.0000000 res: -10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -7.5000000 res: -7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: 7.5000000 res: 7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000001 + (0, 0): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 1.0000000 res: 1.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -6.2500000 res: -6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: -10.0000000 res: -10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: -7.5000000 res: -7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 7.5000000 res: 7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000001 + (0, 0): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 13.7500000 res: 13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -11.2500000 res: -11.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: -5.2500000 res: -5.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 11.2500000 res: 11.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -13.7500000 res: -13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 13.7500000 res: 13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 +PASSED: linalg.add +``` + +``` +# set the arg0 filling follows a distribution N(0, 5) +# set the arg1 filling follows a uniform integer filling [-3, 3] +# use P2P compare strategy on arg2 with threshold = 0 & mistrust rate = 100.0% +# zero threshold will fail the case here + +python3 -m benchgc --verbose 6 --driver linalg --case matmul_transpose_b --md 0:2x5xf32 --md 1:2x5xf32 --md 2:2x2xf32 --fill 0:N:0:5 --fill 1:I:-3:3 --cmp 2:P:0:100 +arg0 shape: [2, 5] dtype: f32 fill_type: N fill_param: ['0', '5'] cmp_type: D cmp_param: ['matmul', 'f32', 'matmul_transpose_b'] +arg1 shape: [2, 5] dtype: f32 fill_type: I fill_param: ['-3', '3'] cmp_type: D cmp_param: ['matmul', 'f32', 'matmul_transpose_b'] +arg2 shape: [2, 2] dtype: f32 fill_type: Z fill_param: [] cmp_type: P cmp_param: ['0', '100'] +module { + func.func @entry(%arg0: tensor<2x5xf32>, %arg1: tensor<2x5xf32>) -> tensor<2x2xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2x2xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = linalg.matmul_transpose_b {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<2x5xf32>, tensor<2x5xf32>) outs(%1 : tensor<2x2xf32>) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> + } +} + +fill arg0: +tensor([[ 7.7050, -1.4671, -10.8939, 2.8422, -5.4226], + [ -6.9930, 2.0167, 4.1901, -3.5963, -2.0167]]) +fill arg1: +tensor([[-3., 0., 1., 0., 0.], + [ 3., -3., 2., -3., 0.]]) +fill arg2: +tensor([[0., 0.], + [0., 0.]]) +p2p check: threshold: 0.0000010 + (0, 0): ref: 7.7049804 res: 7.7049804 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -1.4671445 res: -1.4671445 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: -10.8939466 res: -10.8939466 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 2.8421564 res: 2.8421564 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -5.4226117 res: -5.4226117 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -6.9929771 res: -6.9929771 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 2.0167341 res: 2.0167341 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 4.1901317 res: 4.1901317 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -3.5962880 res: -3.5962880 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: -2.0167177 res: -2.0167177 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000010 + (0, 0): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 1.0000000 res: 1.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: 3.0000000 res: 3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 2.0000000 res: 2.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000000 + (0, 0): ref: -34.0088882 res: -34.0088882 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -2.7979884 res: -2.7979879 abs_diff: 0.0000005 rel_diff: 0.0000002 + (1, 0): ref: 25.1690636 res: 25.1690636 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: -7.8600063 res: -7.8600044 abs_diff: 0.0000019 rel_diff: 0.0000002 +FAIL: linalg.matmul_transpose_b +``` \ No newline at end of file diff --git a/test/benchgc/cases/generic.mlir b/test/benchgc/cases/generic.mlir new file mode 100644 index 000000000..3da555777 --- /dev/null +++ b/test/benchgc/cases/generic.mlir @@ -0,0 +1,15 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func @entry(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x3xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<3x3xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x2xf32>, tensor<2x3xf32>) outs(%0 : tensor<3x3xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<3x3xf32> + return %1 : tensor<3x3xf32> + } +} \ No newline at end of file diff --git a/test/benchgc/cases/llama2.mlir b/test/benchgc/cases/llama2.mlir new file mode 100644 index 000000000..1f557b3e6 --- /dev/null +++ b/test/benchgc/cases/llama2.mlir @@ -0,0 +1,113 @@ +module { + func.func @entry(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) attributes {llvm.emit_c_interface} { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_0 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %67 = arith.addf %in, %init : f32 + linalg.yield %67 : f32 + } + %cst_2 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_2 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_3 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_3, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_4 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf ins(%15, %cst_4 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_5 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_6 = linalg.broadcast ins(%collapsed_5 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_6 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_7 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_7, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_8 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_9 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_9 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_8, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_10 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %cst_11 = arith.constant dense<1.000000e+00> : tensor<1x32x11008xbf16> + %30 = linalg.negf ins(%expanded_10 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = linalg.exp ins(%30 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %32 = linalg.add ins(%cst_11, %31 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %33 = linalg.div ins(%cst_11, %32 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %34 = tensor.empty() : tensor<1x32x11008xbf16> + %35 = linalg.mul ins(%33, %expanded_10 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%34 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_12 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_13 = arith.constant 0.000000e+00 : bf16 + %36 = tensor.empty() : tensor<32x11008xbf16> + %37 = linalg.fill ins(%cst_13 : bf16) outs(%36 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %38 = linalg.matmul_transpose_b ins(%collapsed_12, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%37 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_14 = tensor.expand_shape %38 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %39 = tensor.empty() : tensor<1x32x11008xbf16> + %40 = linalg.mul ins(%35, %expanded_14 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%39 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_15 = tensor.collapse_shape %40 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_16 = arith.constant 0.000000e+00 : bf16 + %41 = tensor.empty() : tensor<32x4096xbf16> + %42 = linalg.fill ins(%cst_16 : bf16) outs(%41 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %43 = linalg.matmul_transpose_b ins(%collapsed_15, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%42 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_17 = tensor.expand_shape %43 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %44 = tensor.empty() : tensor<1x32x4096xbf16> + %45 = linalg.add ins(%4, %expanded_17 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%44 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %46 = tensor.empty() : tensor<1x32x4096xf32> + %47 = linalg.copy ins(%45 : tensor<1x32x4096xbf16>) outs(%46 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %48 = tensor.empty() : tensor<1x32x4096xf32> + %49 = linalg.powf ins(%47, %cst_18 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%48 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_19 = arith.constant 0.000000e+00 : f32 + %50 = tensor.empty() : tensor<1x32xf32> + %51 = linalg.fill ins(%cst_19 : f32) outs(%50 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_20 = linalg.reduce ins(%49 : tensor<1x32x4096xf32>) outs(%51 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %67 = arith.addf %in, %init : f32 + linalg.yield %67 : f32 + } + %cst_21 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %52 = tensor.empty() : tensor<1x32xf32> + %53 = linalg.div ins(%reduced_20, %cst_21 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%52 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_22 = tensor.expand_shape %53 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %54 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_23 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%54 : tensor<1x32x1xf32>) dimensions = [0, 1] + %55 = tensor.empty() : tensor<1x32x1xf32> + %56 = linalg.add ins(%expanded_22, %broadcasted_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%55 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_24 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %57 = tensor.empty() : tensor<1x32x1xf32> + %58 = linalg.powf ins(%56, %cst_24 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%57 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_25 = tensor.collapse_shape %58 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %59 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_26 = linalg.broadcast ins(%collapsed_25 : tensor<1x32xf32>) outs(%59 : tensor<1x32x4096xf32>) dimensions = [2] + %60 = tensor.empty() : tensor<1x32x4096xf32> + %61 = linalg.mul ins(%47, %broadcasted_26 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%60 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %62 = tensor.empty() : tensor<1x32x4096xbf16> + %63 = linalg.copy ins(%61 : tensor<1x32x4096xf32>) outs(%62 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %64 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_27 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%64 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %65 = tensor.empty() : tensor<1x32x4096xbf16> + %66 = linalg.mul ins(%broadcasted_27, %63 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%65 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %66, %45 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> + } +} \ No newline at end of file diff --git a/test/benchgc/cases/reduce.mlir b/test/benchgc/cases/reduce.mlir new file mode 100644 index 000000000..8183319a9 --- /dev/null +++ b/test/benchgc/cases/reduce.mlir @@ -0,0 +1,12 @@ +module { + func.func @entry(%arg0: tensor<3x5xf32>) -> tensor<3xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<3xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<3xf32>) -> tensor<3xf32> + %reduce = linalg.reduce { arith.addf } + ins(%arg0:tensor<3x5xf32>) + outs(%1:tensor<3xf32>) + dimensions = [1] + return %reduce : tensor<3xf32> + } +} \ No newline at end of file diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py new file mode 100644 index 000000000..3d67af539 --- /dev/null +++ b/test/benchgc/setup.py @@ -0,0 +1,30 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import setuptools + +setuptools.setup( + name="benchgc", + description="benchmark tool for graph compiler", + package_dir={ + "benchgc": "src/benchgc", + "gc_mlir": "../../python_packages/gc_mlir_core/gc_mlir", + }, + packages=setuptools.find_packages("src") + + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), + package_data={"gc_mlir": ["_mlir_libs/*.so"]}, + install_requires=["torch", "numpy", "ml_dtypes"], +) diff --git a/test/benchgc/src/benchgc/CMakeLists.txt b/test/benchgc/src/benchgc/CMakeLists.txt new file mode 100644 index 000000000..5700c4f36 --- /dev/null +++ b/test/benchgc/src/benchgc/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/__init__.py b/test/benchgc/src/benchgc/__init__.py new file mode 100644 index 000000000..3b87b7051 --- /dev/null +++ b/test/benchgc/src/benchgc/__init__.py @@ -0,0 +1,20 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import pathlib +import sys + +sys.path.append(pathlib.Path(__file__).parent.resolve().__str__()) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py new file mode 100644 index 000000000..481cfcd91 --- /dev/null +++ b/test/benchgc/src/benchgc/__main__.py @@ -0,0 +1,274 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + + +import argparse +import sys +from typing import Dict, List + +import benchgc.mlir.util +import benchgc.util +import gc_mlir.ir +import runner +import torch +from benchgc.arg import ( + compare_tensor, + fill_tensor, + set_default_compare, + set_default_fill, +) +from benchgc.arg.arg import Arg +from benchgc.mlir.arg import get_mlir_args +from gc_mlir.graph_compiler import GraphCompiler + +try: + parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") + parser.add_argument( + "--driver", + required=False, + help="specify the test driver", + choices=["linalg", "tensor", "mlir", "pattern"], + type=str, + ) + parser.add_argument( + "--case", + required=False, + help="test which operation in the specified driver", + type=str, + ) + + parser.add_argument( + "--md", + required=False, + help="format: #ARG:SHAPExTYPE", + type=str, + default=[], + action="append", + ) + parser.add_argument( + "--fill", + required=False, + help="format: #ARG:type:parameter", + type=str, + default=[], + action="append", + ) + parser.add_argument( + "--cmp", + required=False, + help="format: #ARG:type:parameter", + type=str, + default=[], + action="append", + ) + + parser.add_argument( + "--seed", + required=False, + default=0, + type=int, + help="a seed value to generate data filling", + ) + parser.add_argument( + "--verbose", + type=int, + default=benchgc.util.NO_VERBOSE, + help="verbose level", + choices=[ + benchgc.util.NO_VERBOSE, + benchgc.util.MODULE_VERBOSE, + benchgc.util.ARG_VERBOSE, + benchgc.util.COMPARE_VERBOSE, + benchgc.util.ERROR_OUTPUT_VERBOSE, + benchgc.util.OUTPUT_VERBOSE, + benchgc.util.INPUT_VERBOSE, + ], + ) + parser.add_argument( + "--cast", + required=False, + default="cast_signed", + help="define attribute supported by linalg op such as matmul_transpose_b", + choices=["cast_signed", "cast_unsigned"], + type=str, + ) + + # single dimension index + # linalg.softmax + parser.add_argument( + "--dimension", + required=False, + default=None, + help="define the dimension attribute in linalg op", + type=int, + ) + + # multiple dimensions array + # linalg.broadcast / linalg.reduce + parser.add_argument( + "--dimensions", + required=False, + default=None, + action="append", + help="define the dimensions attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--dilations", + required=False, + default=None, + action="append", + help="define the dilations attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--strides", + required=False, + default=None, + action="append", + help="define the strides attribute in linalg op", + type=int, + ) + flags = parser.parse_args() + benchgc.util.set_seed(flags.seed) + + +except argparse.ArgumentError: + sys.stderr.write("Argument parse failed\n") + sys.exit(1) + +args: List[Arg] = [] + +if flags.driver == "mlir": + # we need to find all args by reading the entry function + with open(flags.case, "r") as mlir_file: + with gc_mlir.ir.Context() as ctx: + module = gc_mlir.ir.Module.parse(mlir_file.read()) + entry = benchgc.mlir.util.get_entry(module) + idx: int = 0 + # FIXME: only support RankTensorType now + for i in entry.type.inputs: + args.append(Arg(idx)) + args[-1].dtype = str(i.element_type) + args[-1].shape = list(i.shape) + args[-1].set_scalar() + idx += 1 + + for o in entry.type.results: + args.append(Arg(idx)) + args[-1].dtype = str(o.element_type) + args[-1].shape = list(o.shape) + args[-1].set_scalar() + idx += 1 +elif flags.driver in ["linalg"]: + # all arg shape/dt should be provided in single op test + for i in range(len(flags.md)): + args.append(Arg(i)) + + for md in flags.md: + colon = md.find(":") + if colon == -1: + raise Exception("Wrong md format: %s", md) + idx = int(md[:colon]) + args[idx].set_md(md[colon + 1 :]) + + from .linalg import mlir_op + + mlir_func = mlir_op[flags.case] + module = mlir_func(flags, args) +else: + raise Exception(f"unsupported driver {flags.driver}") + +for fill in flags.fill: + colon = fill.find(":") + if colon == -1: + raise Exception("Wrong fill format: %s", fill) + idx = int(fill[:colon]) + args[idx].set_fill(fill[colon + 1 :]) + +for cmp in flags.cmp: + colon = cmp.find(":") + if colon == -1: + raise Exception("Wrong cmp format: %s", cmp) + idx = int(cmp[:colon]) + args[idx].set_cmp(cmp[colon + 1 :]) + +entry = benchgc.mlir.util.get_entry(module) + +for i, arg in enumerate(args): + # use zero filling if the arg is return value + set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) + set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) + +for arg in args: + arg.print_verbose(flags.verbose) + +if flags.verbose >= benchgc.util.MODULE_VERBOSE: + print(module) + +ref_args: List[torch.Tensor] = [] +gc_args: List[torch.Tensor | int] = [] +ref_tensors: Dict[str, torch.Tensor] = {} +gc_tensors: Dict[str, torch.Tensor] = {} + +for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + ref_tensors["%arg" + str(i)] = tensor.clone() + ref_args.append(ref_tensors["%arg" + str(i)]) + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + +# ref_out contains return value of the entry +ref_out = runner.ref_run(entry, ref_tensors) + +# we need to swap the result into the args if some arg is the return value +if ref_out is not None: + for i in range(len(ref_out)): + ref_args[0 - i - 1] = ref_out[0 - i - 1] + +entry = "entry" + +mlir_args = get_mlir_args(gc_args) +passes = "any(gc-cpu-pipeline)" + +with module.context: + compiler = GraphCompiler(passes) + engine = compiler.compile_and_jit(module) + engine.invoke(entry, *mlir_args) + +fail, mistrust = False, False +for i in range(len(args)): + # gc_arg contains address for scalar value + # we need to find result by arg name + res = compare_tensor( + args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose + ) + fail = fail or (not res[0]) + if res[1] is not None: + mistrust = mistrust | res[1] +if fail: + print(f"FAIL: {flags.driver}.{flags.case}") + sys.exit(1) +elif mistrust: + print(f"MISTRUST: {flags.driver}.{flags.case}") +else: + print(f"PASSED: {flags.driver}.{flags.case}") diff --git a/test/benchgc/src/benchgc/arg/CMakeLists.txt b/test/benchgc/src/benchgc/arg/CMakeLists.txt new file mode 100644 index 000000000..614e306da --- /dev/null +++ b/test/benchgc/src/benchgc/arg/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/arg/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/arg/__init__.py b/test/benchgc/src/benchgc/arg/__init__.py new file mode 100644 index 000000000..a2134af2d --- /dev/null +++ b/test/benchgc/src/benchgc/arg/__init__.py @@ -0,0 +1,163 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Tuple + +import benchgc.arg.binary as binary +import benchgc.arg.compare +import benchgc.arg.conv as conv +import benchgc.arg.eltwise as eltwise +import benchgc.arg.matmul as matmul +import benchgc.arg.pool as pool +import benchgc.arg.softmax as softmax +import benchgc.util +import torch +from benchgc.arg.arg import Arg + +onednn_module = { + "binary": binary, + "eltwise": eltwise, + "matmul": matmul, + "softmax": softmax, + "conv": conv, + "pool": pool, +} + + +def set_default_fill( + flags: argparse.Namespace, arg: Arg, arglist: List[Arg], is_return: bool +): + if arg.fill_type != "-": + return + + if is_return: + arg.fill_type = "Z" + arg.fill_param = [] + return + + for _, module in onednn_module.items(): + if flags.driver + "." + flags.case in module.op: + module.default_fill(flags, arg, arglist) + return + # use N(0, 1) as default + arg.fill_type = "N" + arg.fill_param = ["0", "1"] + + +def set_default_compare( + flags: argparse.Namespace, arg: Arg, arglist: List[Arg], is_return: bool +): + if arg.cmp_type != "-": + return + + if is_return: + for _, module in onednn_module.items(): + if flags.driver + "." + flags.case in module.op: + module.default_compare(flags, arg, arglist) + return + + dtype: torch.dtype = benchgc.util.get_dtype(arg.dtype) + arg.cmp_type = "P" + if dtype.is_floating_point: + arg.cmp_param = [str(torch.finfo(dtype).eps)] + else: + arg.cmp_param = ["0"] + if is_return: + arg.cmp_param.append("70.0") + else: + arg.cmp_param.append("100.0") + + +def fill_tensor(flags: argparse.Namespace, arg: Arg, idx: int) -> torch.Tensor: + if arg.dtype == "" or arg.fill_type == "": + raise Exception("arg%d filling: dtype/fill_type is not set" % idx) + + # set the seed for the filling + benchgc.util.torch_seed(1, idx) + if arg.fill_type == "N" and len(arg.fill_param) == 2: + # Normal distribution + mean = float(arg.fill_param[0]) + std = float(arg.fill_param[1]) + tensor = torch.normal(mean=mean, std=std, size=arg.shape) + + elif arg.fill_type == "P" and len(arg.fill_param) == 1: + # Poisson distribution + _lambda = float(arg.fill_param[0]) + lambda_tensor = torch.full(arg.shape, _lambda) + tensor = torch.poisson(lambda_tensor) + elif arg.fill_type == "B" and len(arg.fill_param) == 2: + # Binomial distribution + n = int(arg.fill_param[0]) + p = float(arg.fill_param[1]) + bdist = torch.distributions.binomial.Binomial(total_count=n, probs=p) + tensor = bdist.sample(torch.Size(arg.shape)) + elif arg.fill_type == "U" and len(arg.fill_param) == 2: + # Uniform distribution + a = float(arg.fill_param[0]) + b = float(arg.fill_param[1]) + tensor = torch.distributions.uniform.Uniform(a, b).sample(torch.Size(arg.shape)) + elif arg.fill_type == "I" and len(arg.fill_param) == 2: + # integer range + a = int(arg.fill_param[0]) + b = int(arg.fill_param[1]) + tensor = torch.randint(a, b + 1, torch.Size(arg.shape)) + elif arg.fill_type == "F" and len(arg.fill_param) == 1: + # read from pytorch tensor dump file + filename = arg.fill_param[0] + tensor = torch.load(f=filename) + if not isinstance(tensor, torch.Tensor): + raise Exception(f"torch object from file {filename} is not a tensor object") + if tensor.shape != torch.Size(arg.shape): + raise Exception(f"tensor object from file {filename} does not match shape") + if tensor.dtype != benchgc.util.get_dtype(arg.dtype): + raise Exception(f"tensor object from file {filename} does not match dtype") + elif arg.fill_type == "D" and len(arg.fill_param) > 0: + # Driver fill + driver: str = arg.fill_param[0] + driver_module = onednn_module[driver] + tensor = driver_module.fill( + arg.shape, benchgc.util.get_dtype(arg.dtype), arg.fill_param[1:] + ) + elif arg.fill_type == "Z": + tensor = torch.zeros(size=arg.shape) + else: + raise Exception("invalid fill type or fill parameter") + + tensor = tensor.to(benchgc.util.get_dtype(arg.dtype)) + if flags.verbose >= benchgc.util.INPUT_VERBOSE: + print("fill arg%d: " % idx) + print(tensor) + return tensor + + +def compare_tensor( + arg: Arg, ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + + if arg.cmp_type == "P": # p2p check + threshold = float(arg.cmp_param[0]) + zero_percent = float(arg.cmp_param[1]) + return benchgc.arg.compare.p2p(threshold, zero_percent, ref, res, verbose) + if arg.cmp_type == "N": # norm check + threshold = float(arg.cmp_param[0]) + return benchgc.arg.compare.norm(threshold, ref, res, verbose) + elif arg.cmp_type == "D" and len(arg.cmp_param) > 0: # driver check + driver: str = arg.cmp_param[0] + driver_module = onednn_module[driver] + return driver_module.compare(arg.cmp_param[1:], ref, res, verbose) + else: + raise Exception("invalid compare type or compare parameter") diff --git a/test/benchgc/src/benchgc/arg/arg.py b/test/benchgc/src/benchgc/arg/arg.py new file mode 100644 index 000000000..3bf232a93 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/arg.py @@ -0,0 +1,54 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import List + +import benchgc.mlir.arg +import benchgc.util + + +class Arg(benchgc.mlir.arg.MLIRArg): + fill_type: str + fill_param: List[str] + + cmp_type: str + cmp_param: List[str] + + index: int + + def __init__(self, index: int): + self.dtype = "" + self.fill_type = "-" + self.fill_param = [] + self.cmp_type = "-" + self.cmp_param = [] + self.index = index + + def print_verbose(self, verbose: int): + if verbose >= benchgc.util.ARG_VERBOSE: + print( + f"arg{self.index} shape: {self.shape} dtype: {self.dtype} fill_type: {self.fill_type} fill_param: {self.fill_param} cmp_type: {self.cmp_type} cmp_param: {self.cmp_param}" + ) + + def set_fill(self, fill: str): + splited: List[str] = fill.split(":") + self.fill_type = splited[0] + self.fill_param = splited[1:] + + def set_cmp(self, cmp: str): + splited: List[str] = cmp.split(":") + self.cmp_type = splited[0] + self.cmp_param = splited[1:] diff --git a/test/benchgc/src/benchgc/arg/binary.py b/test/benchgc/src/benchgc/arg/binary.py new file mode 100644 index 000000000..6e2cb0c30 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/binary.py @@ -0,0 +1,98 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# op should use this filling + +op: Set[str] = set( + ["linalg.add", "linalg.div", "linalg.mul", "linalg.max", "linalg.min", "linalg.sub"] +) + +# params format: [src0 | src1, src0 dt, src1 dt, dst dt] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("binary fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "binary", + "src0" if arg.index == 0 else "src1", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, _, _, _ = params + + accept_name: Dict[str, int] = {"src0": 1, "src1": 2} + if name in accept_name: + arg: int = accept_name[name] + else: + raise Exception("unknown arg name %s", name) + + range_: int = 16 + f_min = 0 if dtype == torch.uint8 else -range_ // 2 + + idx: torch.Tensor = torch.arange( + benchgc.util.nelem(shape), dtype=torch.int + ).reshape(shape) + values: torch.Tensor = (f_min + (12 * idx + 5 * arg + 16) % (range_ + 1)) * 1.25 + if arg == 2: + values = torch.where(values == 0.0, 1, values) + return values.to(dtype=dtype) + + +# compare param: dtype, case + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["binary", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + + if param[1] in ["div", "div_unsigned"]: + abs_diff = (ref.to(torch.float) - res.to(torch.float)).abs() + init_check = abs_diff < benchgc.util.get_eps(dtype) + else: + init_check = None + + return p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose, init_check) diff --git a/test/benchgc/src/benchgc/arg/compare.py b/test/benchgc/src/benchgc/arg/compare.py new file mode 100644 index 000000000..2e4c31e85 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/compare.py @@ -0,0 +1,137 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Callable, List, Tuple + +import benchgc.util +import numpy +import torch + + +def iterate_tensor(tensor: torch.Tensor, fn: Callable[[Tuple[int, ...]], None]): + if tensor.ndim == 0: + fn(tuple()) + return + index: List[int] = [0] * tensor.ndim + + def dfs(depth: int): + if depth == tensor.ndim: + fn(tuple(index)) + else: + for i in range(tensor.shape[depth]): + index[depth] = i + dfs(depth + 1) + + dfs(0) + + +def norm( + threshold: float, ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + if f32_ref.nelement() == 0: + return (True, None) + + diff_square_sum = torch.square(torch.subtract(f32_ref, f32_res)).sum() + square_sum = torch.square(f32_ref).sum() + + l2_diff_norm = torch.sqrt(diff_square_sum / square_sum).item() + if verbose >= benchgc.util.COMPARE_VERBOSE: + print(f"norm check: {l2_diff_norm:.10f} / threshold: {threshold:.10f}") + + return (l2_diff_norm < threshold, None) + + +def p2p( + threshold: float, + zero_percent: float, + ref: torch.Tensor, + res: torch.Tensor, + verbose: int, + init_check: torch.Tensor | None = None, +) -> Tuple[bool, bool | None]: + + if verbose >= benchgc.util.COMPARE_VERBOSE: + print(f"p2p check: threshold: {threshold:.7f}") + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + + if init_check is None: + check = torch.tensor(False) + else: + check = init_check + + check = check.bitwise_or(torch.bitwise_and(f32_ref.isnan(), f32_res.isnan())) + check = check.bitwise_or(torch.bitwise_and(f32_ref.isneginf(), f32_res.isneginf())) + check = check.bitwise_or(torch.bitwise_and(f32_ref.isposinf(), f32_res.isposinf())) + + # choose diff/rel_diff based on value + abs_diff = (f32_ref - f32_res).abs() + rel_diff = abs_diff / torch.where( + f32_ref.abs() > numpy.finfo(numpy.float32).smallest_subnormal, + f32_ref.abs(), + 1, + ) + # pick a diff for comparison + diff = torch.where(f32_ref.abs() > 1e-5, rel_diff, abs_diff) + + check = check.bitwise_or(diff <= threshold) + + if verbose >= benchgc.util.OUTPUT_VERBOSE: + iterate_tensor( + check, + lambda idx: print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + idx, + f32_ref[idx].item(), + f32_res[idx].item(), + abs_diff[idx].item(), + rel_diff[idx].item(), + ) + ), + ) + if check.all(): + # check mistrusted + zero = res.nelement() - res.count_nonzero().item() + if res.nelement() < 10: + mistrust = False + else: + mistrust = zero * 100.0 / res.nelement() > zero_percent + return (True, mistrust) + else: + if ( + verbose < benchgc.util.OUTPUT_VERBOSE + ): # skip verbose print if full output tensor is alrady printed + fail = torch.argwhere(torch.where(check, 0, 1)) + if verbose < benchgc.util.ERROR_OUTPUT_VERBOSE: + # only print top 10 failed data points if verbose level does not satisfied + fail = fail[:10] + for idx in fail: + index: Tuple[int, ...] = tuple(idx.tolist()) + print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + index, + f32_ref[index].item(), + f32_res[index].item(), + abs_diff[index].item(), + rel_diff[index].item(), + ) + ) + return (False, None) diff --git a/test/benchgc/src/benchgc/arg/conv.py b/test/benchgc/src/benchgc/arg/conv.py new file mode 100644 index 000000000..ba46f201c --- /dev/null +++ b/test/benchgc/src/benchgc/arg/conv.py @@ -0,0 +1,189 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set( + [ + "linalg.conv_1d_ncw_fcw", + "linalg.conv_1d_nwc_wcf", + "linalg.conv_1d", + "linalg.conv_2d_nchw_fchw", + "linalg.conv_2d_ngchw_fgchw", + "linalg.conv_2d_ngchw_gfchw", + "linalg.conv_2d_nhwc_fhwc", + "linalg.conv_2d_nhwc_hwcf", + "linalg.conv_2d", + "linalg.conv_3d_ncdhw_fcdhw", + "linalg.conv_3d_ndhwc_dhwcf", + "linalg.conv_3d", + "linalg.depthwise_conv_1d_ncw_cw", + "linalg.depthwise_conv_1d_nwc_wc", + "linalg.depthwise_conv_1d_nwc_wcm", + "linalg.depthwise_conv_2d_nchw_chw", + "linalg.depthwise_conv_2d_nhwc_hwc", + "linalg.depthwise_conv_2d_nhwc_hwcm", + "linalg.depthwise_conv_3d_ncdhw_cdhw", + "linalg.depthwise_conv_3d_ndhwc_dhwc", + "linalg.depthwise_conv_3d_ndhwc_dhwcm", + ] +) + +# params format: [src | wei, src dt, wei dt, dst dt, amp] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 2: + raise Exception("conv fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "conv", + "src" if arg.index == 0 else "wei", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + # find the amplifier of the conv + wei = arglist[1] + nelem = wei.nelem() + if flags.driver == "linalg": + if flags.case in [ + "conv_1d_ncw_fcw", + "conv_2d_nchw_fchw", + "conv_2d_ngchw_fgchw", + "conv_2d_nhwc_fhwc", + "conv_3d_ncdhw_fcdhw", + ]: + arg.fill_param.append(str(nelem // wei.shape[0])) + elif flags.case in ["conv_2d_ngchw_gfchw"]: + arg.fill_param.append(str(nelem // wei.shape[1])) + elif flags.case in [ + "conv_1d_nwc_wcf", + "conv_2d_nhwc_hwcf", + "conv_3d_ndhwc_dhwcf", + "depthwise_conv_1d_nwc_wcm", + "depthwise_conv_2d_nhwc_hwcm", + "depthwise_conv_3d_ndhwc_dhwcm", + ]: + arg.fill_param.append(str(nelem // wei.shape[-1])) + elif flags.case in [ + "conv_1d", + "conv_2d", + "conv_3d", + "depthwise_conv_1d_ncw_cw", + "depthwise_conv_1d_nwc_wc", + "depthwise_conv_2d_nchw_chw", + "depthwise_conv_2d_nhwc_hwc", + "depthwise_conv_3d_ncdhw_cdhw", + "depthwise_conv_3d_ndhwc_dhwc", + ]: + arg.fill_param.append(str(nelem)) + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, src_dt, wei_dt, dst_dt, amp = params + + arg_rng: List[Dict[torch.dtype, Tuple[int, int]]] = [ + { + torch.float32: (-32, 32), + torch.bfloat16: (-4, 4), + torch.float16: (-4, 4), + }, # src + { + torch.float32: (-32, 32), + torch.bfloat16: (-8, 8), + torch.float16: (-2, 2), + }, # wei + ] + + target = torch.empty(size=shape, dtype=torch.float32) + target = target.view(-1) + + src_dt = benchgc.util.get_dtype(src_dt) + wei_dt = benchgc.util.get_dtype(wei_dt) + + src_min, src_max = arg_rng[0][src_dt] + wei_min, wei_max = arg_rng[1][wei_dt] + max_value = max(abs(src_min), abs(src_max)) * max(abs(wei_min), abs(wei_max)) + safe_digits: int = min( + benchgc.util.get_digits("f32"), benchgc.util.get_digits(dst_dt) + ) + safe_n_acc = (1 << safe_digits) // max_value + + if name == "src": + arg_min, arg_max = arg_rng[0][src_dt] + density = 1.0 + elif name == "wei": + arg_min, arg_max = arg_rng[1][wei_dt] + density = min(safe_n_acc / int(amp), 1.0) + else: + raise Exception("unknown arg name %s", name) + + benchgc.util.torch_seed() + + density_t = torch.full(shape, density, dtype=torch.float32) + bernoulli_t = torch.bernoulli(density_t) + condi = density_t == 1 + is_one_t = torch.where(condi, True, bernoulli_t) + gen_value = torch.randint(arg_min, arg_max + 1, size=shape) + target = is_one_t * gen_value + + # make sure the first element is positive + first_val = target.flatten()[0] + if first_val <= 0.0: + while first_val <= 0.0: + first_val = torch.randint(arg_min, arg_max + 1, size=()).item() + target_f = target.view(-1) + target_f[0] = first_val + target = target_f.view(shape) + + return target.to(dtype=dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["conv", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + 0.0, # use a relax threshold if using wino + 70.0 if dtype == torch.uint8 else 85.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/eltwise.py b/test/benchgc/src/benchgc/arg/eltwise.py new file mode 100644 index 000000000..bf7dbff54 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/eltwise.py @@ -0,0 +1,177 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# params format: [alg, alpha, beta] + +# op should use this filling + +op: Set[str] = set( + [ + "linalg.abs", + "linalg.negf", + "linalg.exp", + "linalg.ceil", + "linalg.erf", + "linalg.floor", + "linalg.log", + "linalg.round", + "linalg.rsqrt", + "linalg.sqrt", + "linalg.square", + "linalg.tanh", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 0: + raise Exception("eltwise fill: dst filling is not allowed") + arg.fill_param = ["eltwise", flags.case] + if flags.driver == "linalg" and flags.case in [ + "abs", + "exp", + "ceil", + "erf", + "floor", + "log", + "round", + "sqrt", + "square", + "tanh", + ]: + arg.fill_param.extend(["", ""]) + elif flags.driver == "linalg" and flags.case == "negf": + arg.fill_param.extend(["-1", "0"]) + elif flags.driver == "linalg" and flags.case == "rsqrt": + arg.fill_param.extend(["1", "-0.5"]) + arg.fill_type = "D" + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + alg, alpha, beta = params + nelems = benchgc.util.nelem(shape) + + float_limit: torch.finfo = torch.finfo(torch.float32) + + alpha = 0.0 if alpha == "" else float(alpha) + beta = 0.0 if beta == "" else float(beta) + + coeff = torch.tensor( + [1, -1, 1, -1, 10.0, -10.0, 10.0, -10.0, 10.0, 10.0, 10.0, 1, 1] + ) + bias = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 88.0, 22.0, 44.0, alpha, beta]) + rand_int_mask = torch.tensor( + [ + True, + True, + False, + False, + True, + True, + False, + False, + False, + False, + False, + False, + False, + ] + ) + rand_uni_mask = torch.tensor( + [ + False, + False, + True, + True, + False, + False, + True, + True, + True, + True, + True, + False, + False, + ] + ) + + if alg == "log": + # append more value for Log validation + coeff = torch.cat((coeff, torch.tensor([1, 1])), dim=0) + bias = torch.cat( + (bias, torch.tensor([float_limit.max, float_limit.min])), dim=0 + ) + rand_int_mask = torch.cat((rand_int_mask, torch.tensor([False, False])), dim=0) + rand_uni_mask = torch.cat((rand_uni_mask, torch.tensor([False, False])), dim=0) + + repeats: int = (nelems + coeff.nelement() - 1) // coeff.nelement() + + coeff = coeff.repeat(repeats)[:nelems] + bias = bias.repeat(repeats)[:nelems] + + rand_int_mask = rand_int_mask.repeat(repeats)[:nelems] + benchgc.util.torch_seed() + rand_int = torch.where(rand_int_mask, torch.randint(0, 10, [nelems]), 0) + + rand_uni_mask = rand_uni_mask.repeat(repeats)[:nelems] + rand_uni = torch.where(rand_uni_mask, torch.rand(nelems) * 0.09, 0) + + value = ((rand_int + rand_uni) * coeff + bias).to(dtype=dtype) + return value.reshape(shape) + + +# param: dtype, case + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["eltwise", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + ref = ref.to(torch.float) + res = res.to(torch.float) + + threshold = 4e-6 if dtype == torch.float else benchgc.util.get_eps(dtype) + if dtype == torch.float and param[1] in ["tanh", "log"]: + threshold = 4e-5 + + return p2p( + threshold, + 65.0 if dtype.is_floating_point else 100.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/matmul.py b/test/benchgc/src/benchgc/arg/matmul.py new file mode 100644 index 000000000..76d0ce81e --- /dev/null +++ b/test/benchgc/src/benchgc/arg/matmul.py @@ -0,0 +1,173 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# params format: [src | wei, src dt, wei dt, dst dt, amp] +# use other filling type for bias + +op: Set[str] = set( + [ + "linalg.batch_matmul", + "linalg.batch_matmul_transpose_a", + "linalg.batch_matmul_transpose_b", + "linalg.batch_matvec", + "linalg.batch_mmt4d", + "linalg.batch_vecmat", + "linalg.batch_reduce_matmul", + "linalg.dot", + "linalg.matmul", + "linalg.matmul_transpose_a", + "linalg.matmul_transpose_b", + "linalg.matvec", + "linalg.mmt4d", + "linalg.vecmat", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("matmul fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "matmul", + "src" if arg.index == 0 else "wei", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + # find the amplifier K of the matmul + if flags.driver == "linalg": + if ( + flags.case == "matmul_transpose_b" + or flags.case == "batch_matmul" + and arg.index == 0 + or flags.case == "batch_matmul_transpose_b" + or flags.case == "batch_matvec" + or flags.case == "batch_vecmat" + and arg.index == 0 + or flags.case == "matmul" + and arg.index == 0 + or flags.case == "matvec" + or flags.case == "vecmat" + and arg.index == 0 + or flags.case == "dot" + ): + arg.fill_param.append(str(arg.shape[-1])) + elif ( + flags.case == "batch_matmul" + and arg.index == 1 + or flags.case == "batch_matmul_transpose_a" + or flags.case == "batch_vecmat" + and arg.index == 1 + or flags.case == "matmul" + and arg.index == 1 + or flags.case == "matmul_transpose_a" + or flags.case == "vecmat" + and arg.index == 1 + ): + arg.fill_param.append(str(arg.shape[-2])) + elif flags.case == "batch_mmt4d" or flags.case == "mmt4d": + arg.fill_param.append(str(arg.shape[-1] * arg.shape[-3])) + # reduce the matmul will amplified by B * K + elif flags.case == "batch_reduce_matmul" and arg.index == 0: + arg.fill_param.append(str(arg.shape[-1] * arg.shape[0])) + elif flags.case == "batch_reduce_matmul" and arg.index == 1: + arg.fill_param.append(str(arg.shape[-2] * arg.shape[0])) + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, src_dt, wei_dt, dst_dt, amp = params + + arg_rng: List[Dict[torch.dtype, Tuple[int, int]]] = [ + { + torch.float32: (-64, 64), + torch.bfloat16: (-4, 4), + torch.float16: (-4, 4), + }, # src + { + torch.float32: (-128, 128), + torch.bfloat16: (-8, 8), + torch.float16: (-2, 2), + }, # wei + ] + + src_dt = benchgc.util.get_dtype(src_dt) + wei_dt = benchgc.util.get_dtype(wei_dt) + + src_min, src_max = arg_rng[0][src_dt] + wei_min, wei_max = arg_rng[1][wei_dt] + max_value = max(abs(src_min), abs(src_max)) * max(abs(wei_min), abs(wei_max)) + safe_digits: int = min( + benchgc.util.get_digits("f32"), benchgc.util.get_digits(dst_dt) + ) + + safe_n_acc = (1 << safe_digits) // max_value + + if name == "src": + arg_min, arg_max = arg_rng[0][src_dt] + density = 1.0 + elif name == "wei": + arg_min, arg_max = arg_rng[1][wei_dt] + density = min(safe_n_acc / int(amp), 1.0) + else: + raise Exception("unknown arg name %s", name) + + benchgc.util.torch_seed(1, 0 if name == "src" else 1) + value = torch.bernoulli(torch.full(shape, density)) * torch.randint( + arg_min, arg_max, shape + ) + while value.flatten()[0] <= 0: + value.flatten()[0] = torch.randint(arg_min, arg_max + 1, size=[1])[0].item() + + return value.to(dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["matmul", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + 1e-6 if dtype == torch.float else benchgc.util.get_eps(dtype), + 90.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/pool.py b/test/benchgc/src/benchgc/arg/pool.py new file mode 100644 index 000000000..7179b7f91 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/pool.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set( + [ + "linalg.pooling_nchw_max", + "linalg.pooling_nchw_sum", + "linalg.pooling_ncw_max", + "linalg.pooling_ncw_sum", + "linalg.pooling_ndhwc_max", + "linalg.pooling_ndhwc_sum", + "linalg.pooling_nhwc_max", + "linalg.pooling_nhwc_sum", + "linalg.pooling_nhwc_min", + "linalg.pooling_nwc_max", + "linalg.pooling_nwc_min", + "linalg.pooling_nwc_sum", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("pool fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = ["pool"] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + arg_rng: Tuple[int, int] = { + torch.float64: (-2048, 2048), + torch.float32: (-2048, 2048), + torch.int32: (-2048, 2048), + torch.bfloat16: (-32, 32), + torch.float16: (-32, 32), + torch.int8: (-128, 127), + torch.uint8: (0, 255), + }[dtype] + + benchgc.util.torch_seed() + target = torch.randint(arg_rng[0], arg_rng[1] + 1, size=[benchgc.util.nelem(shape)]) + # make sure the first element is not negative + if target[0] <= 0.0: + while target[0] <= 0.0: + target[0] = torch.randint(arg_rng[0], arg_rng[1], size=(1,))[0].item() + + return target.reshape(shape).to(dtype=dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["pool", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + benchgc.util.get_eps(dtype) * 10.0, + 99.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/reduce.py b/test/benchgc/src/benchgc/arg/reduce.py new file mode 100644 index 000000000..bc75e5d84 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/reduce.py @@ -0,0 +1,78 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import List, Tuple + +import benchgc.arg +import benchgc.util +import torch + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + + op, sdtype, ddtype, amp = params + + sdtype = benchgc.util.get_dtype(sdtype) + ddtype = benchgc.util.get_dtype(ddtype) + + safe_to_reduce_elems: int = benchgc.util.get_problem_bounds(op, sdtype)[0] + + neutral_value: float = 1.0 if op == "mul" else 0.0 + + shift: float = ( + 1.0 + if ( + op == "mean" + or op == "min" + and not sdtype.is_signed + and not ddtype.is_signed + ) + else 0.0 + ) + + value_range: int = benchgc.util.get_problem_bounds(op, sdtype)[1] + + is_mul_fp: bool = op == "mul" and sdtype.is_floating_point + min_range: int = -value_range if is_mul_fp else 1 + + index = torch.arange(benchgc.util.nelem(shape)).reshape(shape) + + benchgc.util.torch_seed() + value = torch.randint(min_range, value_range + 1, size=shape) + if is_mul_fp: + value = torch.pow(2, value) + if sdtype.is_signed: # random choose positive or negative + value = torch.where(torch.BoolTensor(size=shape), value, -value) + + non_neutral_mask = benchgc.util.flip_coin( + index, + torch.full(shape, safe_to_reduce_elems / int(amp), dtype=torch.float32), + ) + if isinstance(non_neutral_mask, torch.Tensor): + value = torch.where(non_neutral_mask, value, neutral_value) + else: + raise Exception("Flip coin failed when generate the reduce data filling") + value = value + shift + return value.to(dtype) + + +def compare( + ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = ref.dtype + ref = ref.to(torch.float) + res = res.to(torch.float) + return benchgc.arg.p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose) diff --git a/test/benchgc/src/benchgc/arg/softmax.py b/test/benchgc/src/benchgc/arg/softmax.py new file mode 100644 index 000000000..a9731ec0a --- /dev/null +++ b/test/benchgc/src/benchgc/arg/softmax.py @@ -0,0 +1,93 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import operator +from functools import reduce +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set(["linalg.softmax"]) + + +# params format: [reduce dimension] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 0: + raise Exception("softmax fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = ["softmax", str(flags.dimension)] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + dimension: int = int(params[0]) + + outer: int = reduce(operator.mul, shape[:dimension], 1) + inner: int = reduce(operator.mul, shape[dimension + 1 :], 1) + benchgc.util.torch_seed() + sign = torch.randint(0, 1, size=[1, shape[dimension], 1]) * 2 - 1 + value = torch.randint(87, 90, size=[outer, shape[dimension], inner]) + value = torch.where(value == 87, 0, value) + value = value * sign + value = torch.where(value == 0, torch.finfo(dtype).min, value) + return value.reshape(shape).to(dtype) + + +# param: dtype, case, reduce size +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = [ + "softmax", + arg.dtype, + flags.case, + str(arg.shape[int(flags.dimension)]), + ] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + ref = ref.to(torch.float) + res = res.to(torch.float) + + reduce_size = int(param[2]) + nzeros = ( + reduce_size - 1 + if dtype == torch.int8 or dtype == torch.uint8 + else max(0, reduce_size - 8) + ) + + return p2p( + benchgc.util.get_eps(dtype) * (5.0 if dtype == torch.float else 1.0), + 100.0 * nzeros / reduce_size, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arith/CMakeLists.txt b/test/benchgc/src/benchgc/arith/CMakeLists.txt new file mode 100644 index 000000000..63d2bfa79 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/arith COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/arith/__init__.py b/test/benchgc/src/benchgc/arith/__init__.py new file mode 100644 index 000000000..a5f942a72 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/__init__.py @@ -0,0 +1,45 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[ + str, Callable[[argparse.Namespace, List[Arg], List[Arg]], gc_mlir.ir.Module] +] = {} + +for dri in ["basic"]: + mod = importlib.import_module(f"benchgc.arith.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/arith/basic.py b/test/benchgc/src/benchgc/arith/basic.py new file mode 100644 index 000000000..7e4b17467 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/basic.py @@ -0,0 +1,58 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import benchgc.util +import gc_mlir._mlir_libs._mlir.ir +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_constant( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + value = op.attributes["value"] + if isinstance(value, gc_mlir._mlir_libs._mlir.ir.FloatAttr): + return ( + torch.full(size=tuple(), fill_value=value.__float__(), dtype=torch.float), + ) + elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.DenseFPElementsAttr): + if value.is_splat: + return ( + torch.full( + size=tuple(value.type.shape), + fill_value=value.get_splat_value().value, + dtype=benchgc.util.get_dtype(str(value.get_splat_value().type)), + ), + ) + else: + raise Exception("only support splat value now") + else: + raise Exception("Not support constant type %s", type(value)) + + +def ref_mulf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (var[cache.opr[0]] * var[cache.opr[1]],) + + +def ref_addf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (var[cache.opr[0]] + var[cache.opr[1]],) diff --git a/test/benchgc/src/benchgc/linalg/CMakeLists.txt b/test/benchgc/src/benchgc/linalg/CMakeLists.txt new file mode 100644 index 000000000..8daf7848a --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/linalg/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/linalg/__init__.py b/test/benchgc/src/benchgc/linalg/__init__.py new file mode 100644 index 000000000..331bd75dd --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/__init__.py @@ -0,0 +1,52 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], gc_mlir.ir.Module]] = {} + +for dri in [ + "binary", + "matmul", + "eltwise", + "misc", + "generic", + "softmax", + "conv", + "pool", +]: + mod = importlib.import_module(f"benchgc.linalg.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/linalg/binary.py b/test/benchgc/src/benchgc/linalg/binary.py new file mode 100644 index 000000000..ed5d280a3 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/binary.py @@ -0,0 +1,137 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_add( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.add(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.add(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_powf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.pow(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.powf(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_div( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.div(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.div(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.max(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.min(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_mul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mul(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.mul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_sub( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.sub(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.sub(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/conv.py b/test/benchgc/src/benchgc/linalg/conv.py new file mode 100644 index 000000000..c8fc38efb --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/conv.py @@ -0,0 +1,834 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_conv_1d_ncw_fcw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_1d_ncw_fcw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_ncw_fcw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d_nwc_wcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # src: nwc -> ncw + # wei: wcf -> fcw + # dst: nwf -> nfw + + return ( + torch.conv1d( + var[cache.opr[0]].permute([0, 2, 1]), + var[cache.opr[1]].permute([2, 1, 0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_conv_1d_nwc_wcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_nwc_wcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d_ncw_fcw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_1d_ncw_fcw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_ncw_fcw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv1d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_conv_2d_nchw_fchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_2d_nchw_fchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nchw_fchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_ngchw_fgchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + src = var[cache.opr[0]] + wei = var[cache.opr[1]] + groups: int = src.shape[1] + + dst = torch.conv2d( + src.reshape( + [src.shape[0], src.shape[1] * src.shape[2], src.shape[3], src.shape[4]] + ), # merge group axis with channel + wei.transpose(0, 1) + .contiguous() + .reshape( + [wei.shape[0] * wei.shape[1], wei.shape[2], wei.shape[3], wei.shape[4]] + ), # merge group axis with output channel + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + return ( + dst.reshape( + [dst.shape[0], groups, dst.shape[1] // groups, dst.shape[2], dst.shape[3]] + ), + ) # split group axis from output channel + + +def mlir_conv_2d_ngchw_fgchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_ngchw_fgchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_ngchw_gfchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + src = var[cache.opr[0]] + wei = var[cache.opr[1]] + groups: int = src.shape[1] + + dst = torch.conv2d( + src.reshape( + [src.shape[0], src.shape[1] * src.shape[2], src.shape[3], src.shape[4]] + ), # merge group axis with channel + wei.reshape( + [wei.shape[0] * wei.shape[1], wei.shape[2], wei.shape[3], wei.shape[4]] + ), # merge group axis with output channel + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + return ( + dst.reshape( + [dst.shape[0], groups, dst.shape[1] // groups, dst.shape[2], dst.shape[3]] + ), + ) # split group axis from output channel + + +def mlir_conv_2d_ngchw_gfchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_ngchw_gfchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_nhwc_fhwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([0, 3, 1, 2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_conv_2d_nhwc_fhwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nhwc_fhwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_nhwc_hwcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([3, 2, 0, 1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_conv_2d_nhwc_hwcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nhwc_hwcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv2d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_conv_3d_ncdhw_fcdhw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv3d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_3d_ncdhw_fcdhw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d_ncdhw_fcdhw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_3d_ndhwc_dhwcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + var[cache.opr[1]].permute([4, 3, 0, 1, 2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_conv_3d_ndhwc_dhwcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d_ndhwc_dhwcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_3d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv3d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_depthwise_conv_1d_ncw_cw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_1d_ncw_cw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_ncw_cw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_1d_nwc_wc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv1d( + var[cache.opr[0]].transpose(-1, -2), + var[cache.opr[1]].transpose(-1, -2).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .transpose(-1, -2) + .contiguous(), + ) + + +def mlir_depthwise_conv_1d_nwc_wc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_nwc_wc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_1d_nwc_wcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + src = var[cache.opr[0]] + groups: int = src.shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv1d( + src.transpose(-1, -2), + wei.reshape([wei.shape[0], wei.shape[1] * wei.shape[2]]) + .transpose(-1, -2) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .transpose(-1, -2) + .contiguous() + ) + return (dst.reshape([dst.shape[0], dst.shape[1], wei.shape[1], wei.shape[2]]),) + + +def mlir_depthwise_conv_1d_nwc_wcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_nwc_wcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nchw_chw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv2d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_2d_nchw_chw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nchw_chw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nhwc_hwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([2, 0, 1]).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_depthwise_conv_2d_nhwc_hwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nhwc_hwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nhwc_hwcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + wei.reshape([wei.shape[0], wei.shape[1], wei.shape[2] * wei.shape[3]]) + .permute([2, 0, 1]) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 1]) + .contiguous() + ) + return ( + dst.reshape( + [dst.shape[0], dst.shape[1], dst.shape[2], wei.shape[-2], wei.shape[-1]] + ), + ) + + +def mlir_depthwise_conv_2d_nhwc_hwcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nhwc_hwcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ncdhw_cdhw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv3d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_3d_ncdhw_cdhw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ncdhw_cdhw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ndhwc_dhwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + var[cache.opr[1]].permute([3, 0, 1, 2]).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_depthwise_conv_3d_ndhwc_dhwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ndhwc_dhwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ndhwc_dhwcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + wei.reshape( + [wei.shape[0], wei.shape[1], wei.shape[2], wei.shape[3] * wei.shape[4]] + ) + .permute([3, 0, 1, 2]) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous() + ) + return ( + dst.reshape( + [ + dst.shape[0], + dst.shape[1], + dst.shape[2], + dst.shape[3], + wei.shape[-2], + wei.shape[-1], + ] + ), + ) + + +def mlir_depthwise_conv_3d_ndhwc_dhwcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ndhwc_dhwcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/eltwise.py b/test/benchgc/src/benchgc/linalg/eltwise.py new file mode 100644 index 000000000..7ae9b31b7 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/eltwise.py @@ -0,0 +1,197 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_abs( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.abs(var[cache.opr[0]]),) + + +def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.abs(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_ceil( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.ceil(var[cache.opr[0]]),) + + +def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.ceil(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_floor( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.floor(var[cache.opr[0]]),) + + +def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.floor(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_erf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.erf(var[cache.opr[0]]),) + + +def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.erf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.log(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_log( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.log(var[cache.opr[0]]),) + + +def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_negf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.neg(var[cache.opr[0]]),) + + +def ref_exp( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.exp(var[cache.opr[0]]),) + + +def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_round( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # torch.round is following the priciple "round half to even" + # we need another implementation + + v = torch.floor(var[cache.opr[0]]) + return (v + torch.where(var[cache.opr[0]] - v >= 0.5, 1, 0),) + + +def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.round(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_rsqrt( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.rsqrt(var[cache.opr[0]]),) + + +def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.rsqrt(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_sqrt( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.sqrt(var[cache.opr[0]]),) + + +def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.sqrt(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_square( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.square(var[cache.opr[0]]),) + + +def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.square(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_tanh( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.tanh(var[cache.opr[0]]),) + + +def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.tanh(arg0, outs=[args[1].get_zero_op(ctx)])], + ) diff --git a/test/benchgc/src/benchgc/linalg/generic.py b/test/benchgc/src/benchgc/linalg/generic.py new file mode 100644 index 000000000..09e330d5b --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/generic.py @@ -0,0 +1,236 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Any, Dict, List, Tuple + +import benchgc.runner +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def generic_loop( + cache: MLIRCache, + op: gc_mlir.ir.OpView, + depth: int, + iterspace: Dict[str, Tuple[int, int, int]], + affine_from: List[str], + affine_to: List[List[str]], + var: Dict[str, torch.Tensor], + loop_var: Dict[str, torch.Tensor], + result_tensors: Tuple[torch.Tensor, ...], +): + if depth == len(affine_from): + # we need to execute the block here + # we will need to read the block argument name and save it into the cache + + if len(cache.next) == 0: + # region cache + cache.next.append(MLIRCache()) + + block: gc_mlir.ir.Block = op.regions[0].blocks[0] + if len(cache.next[0].next) == 0: + # region->block cache + cache.next[0].next.append(MLIRCache()) + for arg in block.arguments: + cache.next[0].next[0].arg.append(arg.get_name()) + + block_cache = cache.next[0].next[0] + block_arg: Dict[str, torch.Tensor] = {} + for i in range(len(block.arguments)): + index: Tuple[int, ...] = tuple() + aff: List[str] = affine_to[i] + for d in aff: + index = index + (int(loop_var[d].item()),) + + if i + len(op.results) < len(op.regions[0].blocks[0].arguments): + # input argument + block_arg[block_cache.arg[i]] = var[cache.opr[i]][index] + else: + # output argument + block_arg[block_cache.arg[i]] = result_tensors[ + i + len(op.results) - len(block.arguments) + ][index] + + res: Tuple[Any, ...] = benchgc.runner.dfs_block( + cache.next[0].next[0], block, var | loop_var | block_arg + ) + + # perform the yield operation + for i in range(len(op.results)): + idx = -1 - i + aff: List[str] = affine_to[idx] + index: Tuple[int, ...] = tuple() + for d in aff: + index = index + (int(loop_var[d].item()),) + result_tensors[idx][index] = res[idx] + else: + it = iterspace[affine_from[depth]] + for i in range(it[0], it[1], it[2]): + loop_var[affine_from[depth]][0] = i + generic_loop( + cache, + op, + depth + 1, + iterspace, + affine_from, + affine_to, + var, + loop_var, + result_tensors, + ) + + +def ref_generic( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + affine_from: List[str] = [] + affine_to: List[List[str]] = [] + + for affine in op.attributes["indexing_maps"]: + aff = str(affine) + affine_from = aff[aff.find("<(") + 2 : aff.find(") ->")].split(", ") + affine_to.append(aff[aff.find("-> (") + 4 : aff.find(")>")].split(", ")) + + # try to find the iteration space + # TODO: support affine expression + + iterspace: Dict[str, Tuple[int, int, int]] = {} + operands: List[gc_mlir.ir.OpOperand] = list(op.operands) + + loop_var: Dict[str, torch.Tensor] = {} + for d in affine_from: + iterspace[d] = (0, 0, 1) + loop_var[d] = torch.zeros(size=[1], dtype=torch.int) + + for i in range(len(operands)): + for j in range(len(operands[i].type.shape)): + iterspace[affine_to[i][j]] = (0, operands[i].type.shape[j], 1) + + result_tensors: Tuple[torch.Tensor, ...] = tuple() + # create the buffer for result tensors + for i in range(len(op.results)): + result_tensors = result_tensors + (tensors[cache.opr[-1 - i]].clone(),) + + generic_loop( + cache, + op, + 0, + iterspace, + affine_from, + affine_to, + tensors, + loop_var, + result_tensors, + ) + return result_tensors + + +def reduce_loop( + cache: MLIRCache, + op: gc_mlir.ir.OpView, + depth: int, + in_shape: List[int], + var: Dict[str, torch.Tensor], + in_idx: List[int], + out_idx: List[int], + reduced_axis: int, + result_tensor: torch.Tensor, +): + if depth == len(in_shape): + # we need to execute the block here + # we will need to read the block argument name and save it into the cache + + block: gc_mlir.ir.Block = op.regions[0].blocks[0] + + if len(cache.next) == 0: + # region cache + cache.next.append(MLIRCache()) + if len(cache.next[0].next) == 0: + # region->block cache + cache.next[0].next.append(MLIRCache()) + for arg in block.arguments: + cache.next[0].next[0].arg.append(arg.get_name()) + + block_arg: Dict[str, torch.Tensor] = { + # set input + cache.next[0].next[0].arg[0]: var[cache.opr[0]][tuple(in_idx)], + # set output + cache.next[0].next[0].arg[1]: result_tensor[tuple(out_idx)], + } + + res: Tuple[torch.Tensor, ...] = benchgc.runner.dfs_block( + cache.next[0].next[0], op.regions[0].blocks[0], var | block_arg + ) + + # perform the yield operation + result_tensor[tuple(out_idx)] = res[0] + else: + dimensions: gc_mlir.ir.DenseI64ArrayAttr = op.attributes["dimensions"] + reduce_axis: bool = depth in list(dimensions) + + for i in range(in_shape[depth]): + if reduce_axis: + in_idx[depth] = i + reduce_loop( + cache, + op, + depth + 1, + in_shape, + var, + in_idx, + out_idx, + reduced_axis + 1, + result_tensor, + ) + else: + in_idx[depth] = i + out_idx[depth - reduce_axis] = i + reduce_loop( + cache, + op, + depth + 1, + in_shape, + var, + in_idx, + out_idx, + reduced_axis, + result_tensor, + ) + + +def ref_reduce( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # create the buffer for result tensors + tensors[cache.res[0]] = tensors[cache.opr[-1]].clone() + in_shape: List[int] = list(op.operands[0].type.shape) + out_shape: List[int] = list(op.result.type.shape) + + result_tensor: torch.Tensor = tensors[cache.opr[-1]].clone() + reduce_loop( + cache, + op, + 0, + in_shape, + tensors, + [0] * len(in_shape), + [0] * len(out_shape), + 0, + result_tensor, + ) + return (result_tensor,) diff --git a/test/benchgc/src/benchgc/linalg/matmul.py b/test/benchgc/src/benchgc/linalg/matmul.py new file mode 100644 index 000000000..9efde9612 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/matmul.py @@ -0,0 +1,317 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg +from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType + + +def ref_batch_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.matmul(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matmul_transpose_a( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.bmm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) + + +def mlir_batch_matmul_transpose_a( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul_transpose_a(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matmul_transpose_b( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.bmm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) + + +def mlir_batch_matmul_transpose_b( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul_transpose_b(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matvec( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # pytorch does not support bmv + return ( + torch.matmul(var[cache.opr[0]], var[cache.opr[1]].unsqueeze(-1)).squeeze(-1), + ) + + +def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matvec(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_mmt4d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # [B, m, k, m0, k0] -> [B, m, m0, k, k0] + _src = var[cache.opr[0]].permute([0, 1, 3, 2, 4]).contiguous() + # [B, n, k, n0, k0] -> [B, k, k0, n, n0] + _wei = var[cache.opr[1]].permute([0, 2, 4, 1, 3]).contiguous() + + # [B, m, m0, k, k0] -> [B, M, K] + src = _src.reshape( + [_src.shape[0], _src.shape[1] * _src.shape[2], _src.shape[3] * _src.shape[4]] + ) + # [B, k, k0, n, n0] -> [B, K, N] + wei = _wei.reshape( + [_wei.shape[0], _wei.shape[1] * _wei.shape[2], _wei.shape[3] * _wei.shape[4]] + ) + + dst = torch.bmm(src, wei) + # [B, M, N] -> [B, m, m0, n, n0] + dst = dst.reshape( + [dst.shape[0], _src.shape[1], _src.shape[2], _wei.shape[-2], _wei.shape[-1]] + ) + + # [B, m, m0, n, n0] -> [B, m, n, m0, n0] + return (dst.transpose(2, 3).contiguous(),) + + +def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_mmt4d(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_reduce_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.addbmm( + input=torch.zeros(tuple()), + batch1=var[cache.opr[0]], + batch2=var[cache.opr[1]], + beta=0, + alpha=1, + ), + ) + + +def mlir_batch_reduce_matmul( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_reduce_matmul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_vecmat( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), + ) + + +def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_vecmat(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_dot( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.dot(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.dot(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matmul_transpose_a( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) + + +def mlir_matmul_transpose_a( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul_transpose_a( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matmul_transpose_b( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) + + +def mlir_matmul_transpose_b( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul_transpose_b( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matvec( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mv(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matvec(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_mmt4d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # [m, k, m0, k0] -> [m, m0, k, k0] + _src = var[cache.opr[0]].permute([0, 2, 1, 3]).contiguous() + # [n, k, n0, k0] -> [k, k0, n, n0] + _wei = var[cache.opr[1]].permute([1, 3, 0, 2]).contiguous() + + # [m, m0, k, k0] -> [M, K] + src = _src.reshape([_src.shape[0] * _src.shape[1], _src.shape[2] * _src.shape[3]]) + # [k, k0, n, n0] -> [K, N] + wei = _wei.reshape([_wei.shape[0] * _wei.shape[1], _wei.shape[2] * _wei.shape[3]]) + + dst = torch.mm(src, wei) + # [M, N] -> [m, m0, n, n0] + dst = dst.reshape([_src.shape[0], _src.shape[1], _wei.shape[-2], _wei.shape[-1]]) + + # [m, m0, n, n0] -> [m, n, m0, n0] + return (dst.transpose(1, 2).contiguous(),) + + +def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.mmt4d(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_vecmat( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), + ) + + +def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.vecmat(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/misc.py b/test/benchgc/src/benchgc/linalg/misc.py new file mode 100644 index 000000000..cf672956c --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/misc.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import copy +from typing import Dict, List, Tuple + +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir._mlir_libs._mlir.ir import DenseI64ArrayAttr +from gc_mlir.dialects import linalg +from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType + + +# 1. use to reshape to match ndim +# 2. perform broadcast +def ref_broadcast( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + dst_shape: List[int] = op.results[0].type.shape + tmp_shape = copy.copy(dst_shape) + dimensions: DenseI64ArrayAttr = op.attributes["dimensions"] + for d in dimensions: + tmp_shape[d] = 1 + + return (var[cache.opr[0]].reshape(tmp_shape).broadcast_to(dst_shape).contiguous(),) + + +def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.broadcast( + arg0, outs=[args[1].get_zero_op(ctx)], dimensions=flags.dimensions + ) + ], + ) + + +def ref_fill( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.full(tuple(op.results[0].type.shape), var[cache.opr[0]]),) + + +def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.fill( + arg0, outs=[args[1].get_zero_op(ctx)], dimensions=flags.dimensions + ) + ], + ) + + +def ref_copy( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + var[cache.opr[0]] + .to(benchgc.util.get_dtype(str(op.result.type.element_type))) + .clone(), + ) + + +def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.copy( + arg0, outs=[args[1].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/pool.py b/test/benchgc/src/benchgc/linalg/pool.py new file mode 100644 index 000000000..9779256df --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/pool.py @@ -0,0 +1,489 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_pooling_nchw_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]], + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_pooling_nchw_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nchw_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nchw_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool2d or lp_pool2d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[1] + kernel = var[cache.opr[1]] + return ( + torch.conv2d( + var[cache.opr[0]], + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ), + ) + + +def mlir_pooling_nchw_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nchw_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ncw_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]], + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_pooling_ncw_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ncw_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ncw_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool1d or lp_pool1d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[1] + kernel = var[cache.opr[1]] + return ( + torch.conv1d( + var[cache.opr[0]], + torch.ones(channel, 1, kernel.shape[0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ), + ) + + +def mlir_pooling_ncw_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ncw_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ndhwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool3d( + var[cache.opr[0]].permute([0, -1, 1, 2, 3]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_pooling_ndhwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ndhwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ndhwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool3d or lp_pool3d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, -1, 1, 2, 3]), + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_pooling_ndhwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ndhwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]].permute([0, -1, 1, 2]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_pooling_nhwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool2d or lp_pool2d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, -1, 1, 2]), + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_pooling_nhwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]].permute([0, -1, 1, 2]).neg(), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .neg() + .contiguous(), + ) + + +def mlir_pooling_nhwc_min( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_min( + arg0, + arg1, + outs=[args[2].get_max_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]].permute([0, -1, 1]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_pooling_nwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]].permute([0, -1, 1]).neg(), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous() + .neg(), + ) + + +def mlir_pooling_nwc_min( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_min( + arg0, + arg1, + outs=[args[2].get_max_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool3d or lp_pool3d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv1d( + var[cache.opr[0]].permute([0, -1, 1]), + torch.ones(channel, 1, kernel.shape[0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_pooling_nwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/softmax.py b/test/benchgc/src/benchgc/linalg/softmax.py new file mode 100644 index 000000000..20ed39fcb --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/softmax.py @@ -0,0 +1,47 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_softmax( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + dimension: gc_mlir.ir.IntegerAttr = op.attributes["dimension"] + return (torch.softmax(var[cache.opr[0]], dimension.value),) + + +def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.softmax( + result=[args[1].get_ranked_tensor_type(ctx)], + input=arg0, + output=args[1].get_zero_op(ctx), + dimension=flags.dimension, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/mlir/CMakeLists.txt b/test/benchgc/src/benchgc/mlir/CMakeLists.txt new file mode 100644 index 000000000..5f8d589b4 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/mlir/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/mlir/__init__.py b/test/benchgc/src/benchgc/mlir/__init__.py new file mode 100644 index 000000000..4d3e897ce --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/__init__.py @@ -0,0 +1,15 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py new file mode 100644 index 000000000..364b9d92c --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -0,0 +1,174 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import ctypes +from typing import Any, List + +import benchgc.util +import gc_mlir.dialects.arith +import gc_mlir.dialects.linalg +import gc_mlir.dialects.tensor +import gc_mlir.ir +import torch +from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr + + +# scalar should give a address +# map torch.Tensor -> memref +# map int address -> scalar value +def get_mlir_args(args: List[torch.Tensor | int]): + mlir_args: List[Any] = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + mlir_args.append(ctypes.pointer(ctypes.pointer(get_md(arg)))) + else: + mlir_args.append(ctypes.c_void_p(arg)) + + return mlir_args + + +def get_md(tensor: torch.Tensor): + if tensor.ndim == 0: + + class _0dMemrefDescriptor(ctypes.Structure): + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype_to_ctype(tensor.dtype))), + ("offset", ctypes.c_longlong), + ] + + md = _0dMemrefDescriptor() + else: + ctype_shape = ctypes.c_longlong * tensor.ndim + ctype_strides = ctypes.c_longlong * tensor.ndim + + class _ndMemrefDescriptor(ctypes.Structure): + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype_to_ctype(tensor.dtype))), + ("offset", ctypes.c_longlong), + ("shape", ctype_shape), + ("strides", ctype_strides), + ] + + md = _ndMemrefDescriptor() + md.shape = ctype_shape(*tensor.shape) + md.strides = ctype_strides(*tensor.stride()) + + md.allocated = tensor.data_ptr() + md.aligned = ctypes.cast( + ctypes.c_void_p(tensor.data_ptr()), ctypes.POINTER(dtype_to_ctype(tensor.dtype)) + ) + md.offset = ctypes.c_longlong(0) + return md + + +class MLIRArg: + dtype: str + shape: List[int] + + scalar: bool + + def __init__(self) -> None: + self.dtype = "" + + # md format: + # 0d memref/tensor: 0xf32 + # nd memref/tensor: 2x3xf32 + # scalar: f32 + def set_md(self, md: str): + splited: List[str] = md.split("x") + self.dtype = splited[-1] + self.shape = [] + + for dim in splited[:-1]: + self.shape.append(int(dim)) + self.set_scalar() + + def set_scalar(self): + # use 0xf32 to represent memref + # use f32 to represent f32 + if self.shape == [0]: + self.shape = [] + self.scalar = False + elif self.shape == []: + self.scalar = True + else: + self.scalar = False + + def nelem(self) -> int: + if self.scalar or self.shape == [] or self.shape[0] == 0: + return 1 + ret: int = 1 + for dim in self.shape: + ret = ret * dim + return ret + + def get_mlir_type(self, ctx: gc_mlir.ir.Context) -> gc_mlir.ir.Type: + if self.scalar: + return str_to_mlir_dtype(ctx, self.dtype) + else: + return gc_mlir.ir.RankedTensorType.get( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + + def get_ranked_tensor_type( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.ir.RankedTensorType: + return gc_mlir.ir.RankedTensorType.get( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + + def get_constant_op( + self, ctx: gc_mlir.ir.Context, cst: Any + ) -> gc_mlir.dialects.tensor.OpView: + zero = gc_mlir.dialects.arith.ConstantOp( + value=str_to_mlir_typed_attr(ctx, self.dtype, cst), + result=str_to_mlir_dtype(ctx, self.dtype), + ) + if self.scalar: + return zero + else: + return gc_mlir.dialects.linalg.fill( + zero, + outs=[ + gc_mlir.dialects.tensor.EmptyOp( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + ], + ) + + def get_zero_op(self, ctx: gc_mlir.ir.Context) -> gc_mlir.dialects.tensor.OpView: + return self.get_constant_op(ctx, 0) + + def get_max_value_op( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.dialects.tensor.OpView: + dtype = benchgc.util.get_dtype(self.dtype) + if dtype.is_floating_point: + return self.get_constant_op(ctx, torch.finfo(dtype).max) + else: + return self.get_constant_op(ctx, torch.iinfo(dtype).max) + + def get_min_value_op( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.dialects.tensor.OpView: + dtype = benchgc.util.get_dtype(self.dtype) + if dtype.is_floating_point: + return self.get_constant_op(ctx, torch.finfo(dtype).min) + else: + return self.get_constant_op(ctx, torch.iinfo(dtype).min) diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py new file mode 100644 index 000000000..806c9d8b7 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -0,0 +1,48 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Callable, List, Tuple + +import gc_mlir.dialects.tensor +import gc_mlir.ir +from benchgc.mlir.arg import MLIRArg +from gc_mlir.dialects import func + + +def init_module( + inputs: Tuple[MLIRArg, ...], + outputs: Tuple[MLIRArg, ...], + op_func: Callable[ + [gc_mlir.ir.Context, Tuple[gc_mlir.ir.BlockArgument, ...]], + List[gc_mlir.ir.OpResult], + ], +) -> gc_mlir.ir.Module: + with gc_mlir.ir.Context() as ctx, gc_mlir.ir.Location.unknown(): + module = gc_mlir.ir.Module.create() + with gc_mlir.ir.InsertionPoint(module.body): + f = func.FuncOp( + name="entry", + type=gc_mlir.ir.FunctionType.get( + inputs=[x.get_mlir_type(ctx) for x in inputs], + results=[x.get_mlir_type(ctx) for x in outputs], + ), + ) + f.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() + + with gc_mlir.ir.InsertionPoint(f.add_entry_block()): + block_args = f.entry_block.arguments + func.ReturnOp(op_func(ctx, *block_args)) + return module diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py new file mode 100644 index 000000000..24169bca1 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -0,0 +1,111 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import ctypes +from typing import Any, List + +import gc_mlir.ir +import torch +from gc_mlir.dialects import func + +# only python 3.11 support +# from typing import Self + + +def get_entry(module: gc_mlir.ir.Module, entry: str = '"entry"') -> func.FuncOp: + for op in module.operation.opview.regions[0].blocks[0].operations: + if str(op.name) == entry: + return op + raise Exception(f"entry function {entry} is not found at the top level") + + +# calling python binding consumes a lot of time e.g. get_name() +# we need to cache some result to avoid duplicate call +class MLIRCache: + # operand name cache + opr: List[str] + # result name cache + res: List[str] + # argument name cache + arg: List[str] + # next hierarchy + next = [] # List[Self] + + def __init__(self): + self.opr = [] + self.res = [] + self.arg = [] + self.next = [] + + +def dtype_to_ctype(dtype: torch.dtype): + if dtype == torch.float32: + return ctypes.c_float + elif dtype == torch.float64: + return ctypes.c_double + elif dtype == torch.int32: + return ctypes.c_int + elif dtype == torch.int64: + return ctypes.c_longlong + elif dtype == torch.uint8: + return ctypes.c_ubyte + elif dtype == torch.int8: + return ctypes.c_byte + elif dtype == torch.int16 or dtype == torch.bfloat16 or torch.float16: + return ctypes.c_short + elif dtype == torch.bool: + return ctypes.c_bool + else: + raise ValueError(f"Unsupported torch dtype: {dtype}") + + +def str_to_mlir_dtype(ctx: gc_mlir.ir.Context, dtype: str) -> gc_mlir.ir.Type: + if dtype == "f32": + return gc_mlir.ir.F32Type.get(ctx) + elif dtype == "f64": + return gc_mlir.ir.F64Type.get(ctx) + elif dtype == "f16": + return gc_mlir.ir.F16Type.get(ctx) + elif dtype == "bf16": + return gc_mlir.ir.BF16Type.get(ctx) + elif dtype == "u8": + return gc_mlir.ir.IntegerType.get_unsigned(8, ctx) + elif dtype == "s8": + return gc_mlir.ir.IntegerType.get_signed(8, ctx) + elif dtype == "boolean": + return gc_mlir.ir.IntegerType.get_unsigned(1, ctx) + elif dtype == "f8_e4m3": + return gc_mlir.ir.Float8E4M3FNType.get(ctx) + elif dtype == "f8_e5m2": + return gc_mlir.ir.Float8E5M2Type.get(ctx) + elif dtype == "s32": + return gc_mlir.ir.IntegerType.get_signed(32, ctx) + else: + raise Exception(f"data type not support: {dtype}") + + +def str_to_mlir_typed_attr( + ctx: gc_mlir.ir.Context, dtype: str, value: Any +) -> gc_mlir.ir.Attribute: + mlir_dtype = str_to_mlir_dtype(ctx, dtype) + if dtype in ["f32", "f64", "bf16", "f16", "f8_e4m3", "f8_e5m2"]: + return gc_mlir.ir.FloatAttr.get(mlir_dtype, value) + elif dtype in ["u8", "s8", "s32"]: + return gc_mlir.ir.IntegerAttr.get(mlir_dtype, value) + elif dtype == "boolean": + return gc_mlir.ir.BoolAttr.get(value) + else: + raise Exception(f"data type not support: {dtype}") diff --git a/test/benchgc/src/benchgc/runner.py b/test/benchgc/src/benchgc/runner.py new file mode 100644 index 000000000..80178baa8 --- /dev/null +++ b/test/benchgc/src/benchgc/runner.py @@ -0,0 +1,109 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import gc_mlir._mlir_libs +import gc_mlir.dialects +import gc_mlir.dialects.func +import gc_mlir.ir +import torch +from benchgc.arith import ref_op as arith_ref_op +from benchgc.linalg import ref_op as linalg_ref_op +from benchgc.mlir.util import MLIRCache +from benchgc.tensor import ref_op as tensor_ref_op + + +def dfs_op( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + + dialect_call: str = str(op.name) + if dialect_call in ["func.return", "linalg.yield"]: + ret: Tuple[torch.Tensor, ...] = tuple() + for name in cache.opr: + ret = ret + (tensors[name],) + return ret + if dialect_call.startswith("linalg"): + ref_op = linalg_ref_op + elif dialect_call.startswith("tensor"): + ref_op = tensor_ref_op + elif dialect_call.startswith("arith"): + ref_op = arith_ref_op + else: + build_cache = len(cache.next) == 0 + for i in range(len(op.regions)): + if build_cache: + # we do not need to cache things for region + # keep an empty cache + cache.next.append(MLIRCache()) + ret = dfs_region(cache.next[i], op.regions[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + dialect_op: str = dialect_call.split(".")[1] + if dialect_op not in ref_op: + raise Exception(f"unknown op call {dialect_call}") + ref_func = ref_op[dialect_op] + for i, res in enumerate(ref_func(cache, op, tensors)): + tensors[cache.res[i]] = res + return tuple() + + +def dfs_region( + cache: MLIRCache, region: gc_mlir.ir.Region, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + build_cache = len(cache.next) == 0 + for i in range(len(region.blocks)): + if build_cache: + _cache = MLIRCache() + # we need to cache argument name for block object + for arg in region.blocks[i].arguments: + _cache.arg.append(arg.get_name()) + cache.next.append(_cache) + ret = dfs_block(cache.next[i], region.blocks[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + +def dfs_block( + cache: MLIRCache, block: gc_mlir.ir.Block, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + build_cache = len(cache.next) == 0 + for i in range(len(block.operations)): + if build_cache: + _cache = MLIRCache() + # we need to cache operand name and result name + for opr in block.operations[i].operands: + _cache.opr.append(opr.get_name()) + + for res in block.operations[i].results: + _cache.res.append(res.get_name()) + cache.next.append(_cache) + + ret = dfs_op(cache.next[i], block.operations[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + +def ref_run( + entry: gc_mlir.dialects.func.FuncOp, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # cache some information of block & op + return dfs_op(MLIRCache(), entry, tensors) diff --git a/test/benchgc/src/benchgc/tensor/CMakeLists.txt b/test/benchgc/src/benchgc/tensor/CMakeLists.txt new file mode 100644 index 000000000..7b1b990dc --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/tensor/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/tensor/__init__.py b/test/benchgc/src/benchgc/tensor/__init__.py new file mode 100644 index 000000000..2f8bc98a4 --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/__init__.py @@ -0,0 +1,45 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[ + str, Callable[[argparse.Namespace, Dict[str, Arg]], gc_mlir.ir.Module] +] = {} + +for dri in ["basic", "shape"]: + mod = importlib.import_module(f"benchgc.tensor.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/tensor/basic.py b/test/benchgc/src/benchgc/tensor/basic.py new file mode 100644 index 000000000..eb56aafbc --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/basic.py @@ -0,0 +1,33 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_empty( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.zeros( + size=op.results[0].type.shape, + dtype=benchgc.util.get_dtype(str(op.results[0].type.element_type)), + ), + ) diff --git a/test/benchgc/src/benchgc/tensor/shape.py b/test/benchgc/src/benchgc/tensor/shape.py new file mode 100644 index 000000000..18d9fbb2c --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/shape.py @@ -0,0 +1,59 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_collapse_shape( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # permute axis and do reshape + reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + permutation: List[int] = [] + shape: List[int] = [] + for outdim in reassociation: + d: int = 1 + for indim in outdim: + permutation.append(int(indim)) + d = d * int(op.operands[0].type.shape[int(indim)]) + shape.append(d) + return ( + torch.permute(var[cache.opr[0]], tuple(permutation)) + .contiguous() + .reshape(shape) + .contiguous(), + ) + + +def ref_expand_shape( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # permute axis and do reshape + reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + permutation: List[int] = [0] * len(op.result.type.shape) + shape: List[int] = [] + + d: int = 0 + for indim in reassociation: + for outdim in indim: + shape.append(int(op.result.type.shape[int(outdim)])) + permutation[int(outdim)] = d + d = d + 1 + return (torch.reshape(var[cache.opr[0]], shape).permute(permutation).contiguous(),) diff --git a/test/benchgc/src/benchgc/util.py b/test/benchgc/src/benchgc/util.py new file mode 100644 index 000000000..5f362bf00 --- /dev/null +++ b/test/benchgc/src/benchgc/util.py @@ -0,0 +1,341 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import operator +import random +from functools import reduce +from typing import Any, Callable, List, Tuple, Union + +import ml_dtypes +import numpy +import torch + +# verbose level +NO_VERBOSE = 0 +MODULE_VERBOSE = 1 # print the module will be executed +ARG_VERBOSE = 2 # + print arg information +COMPARE_VERBOSE = 3 # + print threshold for comparison +ERROR_OUTPUT_VERBOSE = 4 # + print all error data points if failed +OUTPUT_VERBOSE = 5 # + print all result including passed tensor +INPUT_VERBOSE = 6 # + print input torch tensors + +""" +acc | acc | elems | value_range | worst case +s32 | mul | 10 | 3 | 3^10=2^16, out of 2^30 (max integer) +f16 | mul | 10 | 1 | (2^1)^10=2^10, out of 2^16 (max exponent) +f32 | mul | 30 | 3 | (2^3)^30=2^90, out of 2^128 (max exponent) +s32 | sum | 10000 | 50 | 10000*50=2^19, out of 2^30 (max integer) +f16 | sum | 1000 | 8 | 1000*8=2^13, out of 2^10 (max mantissa/integer) +f32 | sum | 10000 | 16 | 10000*16=2^18, out of 2^23 (max mantissa/integer) + min/max | all | 1000 | no limits on accumulation chain + +In f16 cases, the worst case exceeds the data type bounds, however it's rare +to reach these extreme cases as long as they're close (can't just use f32 bounds) +""" +# first: nonneutral elements +# second: maximum range +_problem_bounds = { + "mul_int": (10, 3), + "mul_fp16": (10, 1), + "mul_fp32": (30, 3), + "sum_int": (10000, 50), + "sum_fp16": (1000, 8), + "sum_fp32": (10000, 16), + "minmax_int": (-1, 1000), + "minmax_fp": (-1, 1000), +} +_dtype_2_range = { + "f32": (-16777216, 16777216), + "f64": (-16777216, 16777216), + "f16": (-2048, 2048), + "bf16": (-16777216, 16777216), + "f8_e5m2": (-2048, 2048), + "f8_e4m3": (-2048, 2048), + "u8": (0, 255), + "s8": (-128, 127), + "s32": (-2147483648, 2147483520), +} + + +def flip_coin( + seed: Union[Any, torch.Tensor], prob: Union[float, torch.Tensor] +) -> Union[bool, torch.Tensor]: + big_prime: int = 1000003 + prime: int = 753737 + seed = seed * prime + return (seed % big_prime) < (prob * big_prime) + + +def get_problem_bounds(kind: str, dt: torch.dtype) -> Tuple[int, int]: + if not dt.is_floating_point: + if kind in ["max", "min"]: + return _problem_bounds["minmax_int"] + elif kind == "mul": + return _problem_bounds["mul_int"] + else: + return _problem_bounds["sum_int"] + elif kind in ["max", "min"]: + return _problem_bounds["minmax_fp"] + elif kind == "mul": + return ( + _problem_bounds["mul_fp16"] + if dt == torch.float16 + else _problem_bounds["mul_fp32"] + ) + else: + return ( + _problem_bounds["sum_fp16"] + if dt == torch.float16 + else _problem_bounds["sum_fp32"] + ) + + +def get_type_range(dt: str) -> Tuple[float, float]: + return _dtype_2_range[dt] + + +# Lnorm, Bnorm & Conv +def get_digits(dtype: str) -> int: + return { + "f32": 24, + "f64": 53, + "s8": 7, + "u8": 8, + "f16": 11, + "bf16": 8, + "f8_e5m2": 3, + "f8_e4m3": 4, + }[dtype] + + +def get_dtype(dtype: str) -> torch.dtype: + if dtype == "f32": + return torch.float32 + elif dtype == "f64": + return torch.float64 + elif dtype == "f16": + return torch.float16 + elif dtype == "bf16": + return torch.bfloat16 + elif dtype == "u8" or dtype == "ui8": + return torch.uint8 + elif dtype == "s8" or dtype == "i8": + return torch.int8 + elif dtype == "boolean": + return torch.uint8 + elif dtype == "f8_e4m3": + return torch.float8_e4m3fn + elif dtype == "f8_e5m2": + return torch.float8_e5m2 + elif dtype == "s32" or dtype == "i32": + return torch.int32 + else: + raise Exception(f"data type not support: {dtype}") + + +def tensor_to_ndarray(tensor: torch.Tensor) -> Any: + if tensor.dtype == torch.bfloat16: + return tensor.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + return tensor.numpy() + + +def get_eps(dtype: torch.dtype) -> float: + return torch.finfo(dtype).eps if dtype.is_floating_point else 0.0 + + +_seed: int = 0 + + +def set_seed(seed: int): + global _seed + _seed = seed + + +def torch_seed(seed_scale: int = 1, seed_shift: int = 0): + seed: int = _seed * seed_scale + seed_shift + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + numpy.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def iterate_tensor(tensor: torch.Tensor, fn: Callable[[Tuple[int, ...]], None]): + index: List[int] = [0] * tensor.ndim + + def dfs(depth: int): + if depth == tensor.ndim: + fn(tuple(index)) + else: + for i in range(tensor.shape[depth]): + index[depth] = i + dfs(depth + 1) + + +# indicate how to check the result +class Checker: + use_norm: bool + # set if negative result is trancated to zero + truncate_negative: bool + eltwise_relax: bool + threshold: float + zero_percent: float + # args: [ref, res, abs_diff, rel_diff] + customized_checker: ( + Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + | None + ) + + def __init__( + self, + threshold: float, + zero_percent: float, + use_norm: bool = False, + eltwise_relax: bool = False, + truncate_negative: bool = False, + checker: ( + Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor + ] + | None + ) = None, + ) -> None: + self.use_norm = use_norm + self.eltwise_relax = eltwise_relax + self.threshold = threshold + self.zero_percent = zero_percent + self.truncate_negative = truncate_negative + self.customized_checker = checker + + def check( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + if self.use_norm: + return self.norm(ref, res, verbose) + else: + return self.p2p(ref, res, verbose) + + def norm( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + if f32_ref.nelement() == 0: + return (True, None) + + diff_square_sum = torch.square(torch.subtract(f32_ref, f32_res)).sum() + square_sum = torch.square(f32_ref).sum() + + l2_diff_norm = torch.sqrt(diff_square_sum / square_sum).item() + if verbose >= COMPARE_VERBOSE: + print(f"norm check: {l2_diff_norm:.10f} / threshold: {self.threshold:.10f}") + + return (l2_diff_norm < self.threshold, None) + + def p2p( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + + if verbose >= COMPARE_VERBOSE: + print(f"p2p check: threshold: {self.threshold:.7f}") + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + + check = torch.BoolTensor([False]) + + check = check.bitwise_or(torch.bitwise_and(f32_ref.isnan(), f32_res.isnan())) + check = check.bitwise_or( + torch.bitwise_and(f32_ref.isneginf(), f32_res.isneginf()) + ) + check = check.bitwise_or( + torch.bitwise_and(f32_ref.isposinf(), f32_res.isposinf()) + ) + + # choose diff/rel_diff based on value + abs_diff = (f32_ref - f32_res).abs() + rel_diff = abs_diff / torch.where( + f32_ref.abs() > numpy.finfo(numpy.float32).smallest_subnormal, + f32_ref.abs(), + 1, + ) + # pick a diff for comparison + diff = torch.where(f32_ref.abs() > 1e-5, rel_diff, abs_diff) + + check = check.bitwise_or(diff <= self.threshold) + + if self.eltwise_relax: + check = check.bitwise_or(abs_diff <= max(torch.finfo(res.dtype).eps, 2e-5)) + + if self.customized_checker is not None: + check = check.bitwise_or( + self.customized_checker(ref, res, abs_diff, rel_diff) + ) + + if verbose >= OUTPUT_VERBOSE: + iterate_tensor( + check, + lambda idx: print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + idx, + f32_ref[idx].item(), + f32_res[idx].item(), + abs_diff[idx].item(), + rel_diff[idx].item(), + ) + ), + ) + if check.all(): + # check mistrusted + zero = res.nelement() - res.count_nonzero().item() + if res.nelement() < 10: + mistrust = False + elif self.truncate_negative: + mistrust = ( + zero * 100.0 / res.nelement() > 50.0 + self.zero_percent / 2.0 + ) + else: + mistrust = zero * 100.0 / res.nelement() > self.zero_percent + return (True, mistrust) + else: + if ( + verbose < OUTPUT_VERBOSE + ): # skip verbose print if full output tensor is alrady printed + fail = torch.argwhere(torch.where(check, 0, 1)) + if verbose < ERROR_OUTPUT_VERBOSE: + fail = fail[ + :10 + ] # only print top 10 failed data points if verbose level does not satisfied + for idx in fail: + index: Tuple[int, ...] = tuple(idx.tolist()) + print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + index, + f32_ref[index].item(), + f32_res[index].item(), + abs_diff[index].item(), + rel_diff[index].item(), + ) + ) + return (False, None) + + +def nelem(shape: List[int]) -> int: + return reduce(operator.mul, shape) From 1f5b6ba88d2ae8b724a5238c4ce629d096559b74 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 01:42:49 -0700 Subject: [PATCH 02/38] merge code --- test/benchgc/CMakeLists.txt | 1 + test/benchgc/README.md | 25 +- test/benchgc/src/benchgc/__main__.py | 342 +++++++++++------- test/benchgc/src/benchgc/mlir/bench.py | 206 +++++++++++ test/benchgc/src/benchgc/mlir/util.py | 94 ++++- .../src/benchgc/pattern/CMakeLists.txt | 22 ++ test/benchgc/src/benchgc/pattern/__init__.py | 20 + test/benchgc/src/benchgc/pattern/mlp.py | 320 ++++++++++++++++ 8 files changed, 884 insertions(+), 146 deletions(-) create mode 100644 test/benchgc/src/benchgc/mlir/bench.py create mode 100644 test/benchgc/src/benchgc/pattern/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/pattern/__init__.py create mode 100644 test/benchgc/src/benchgc/pattern/mlp.py diff --git a/test/benchgc/CMakeLists.txt b/test/benchgc/CMakeLists.txt index e50f35cf2..d31895de3 100644 --- a/test/benchgc/CMakeLists.txt +++ b/test/benchgc/CMakeLists.txt @@ -39,3 +39,4 @@ add_subdirectory("src/benchgc/mlir") add_subdirectory("src/benchgc/linalg") add_subdirectory("src/benchgc/tensor") add_subdirectory("src/benchgc/arith") +add_subdirectory("src/benchgc/pattern") diff --git a/test/benchgc/README.md b/test/benchgc/README.md index 77499c5fd..ab63cb156 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -25,9 +25,13 @@ python -m pip install test/benchgc/dist/benchgc-*.whl ## Synopsis ``` -python -m benchgc [OPTIONS] --driver [DRIVER] --case [CASE] +python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] ``` ## Flags +### --mode [str] +* C : correctness testing (by default) +* P : performance testing + ### --driver [str] * linalg: test the single op in linalg dialect * mlir: upload a mlir file and run @@ -97,7 +101,24 @@ module { | Norm check | N | threshold | | Benchdnn driver | D | driver_name:dtype:case | -## Example +### --pattern + + + + +## Perfermance testing flags +### --bench_kind [str] +* py +* wrapper + +### --warm_up [int] +* warm-up times of the execution + +### --repeat +* repeat times of the execution + +### Example +### Correctness testing example ``` # single add op test # using the same data filling / compare strategy as the benchdnn primitive driver if not set diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 481cfcd91..655d2b7f2 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -32,10 +32,18 @@ ) from benchgc.arg.arg import Arg from benchgc.mlir.arg import get_mlir_args +from benchgc.pattern.mlp import MLP from gc_mlir.graph_compiler import GraphCompiler -try: - parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") +def add_common_options(parser: argparse.ArgumentParser): + parser.add_argument( + "--mode", + required=False, + help="specify the test mode, C for correctness testing, P for performance testing", + choices=["C", "P"], + default="C", + type=str, + ) parser.add_argument( "--driver", required=False, @@ -144,131 +152,215 @@ help="define the strides attribute in linalg op", type=int, ) - flags = parser.parse_args() - benchgc.util.set_seed(flags.seed) -except argparse.ArgumentError: - sys.stderr.write("Argument parse failed\n") - sys.exit(1) - -args: List[Arg] = [] - -if flags.driver == "mlir": - # we need to find all args by reading the entry function - with open(flags.case, "r") as mlir_file: - with gc_mlir.ir.Context() as ctx: - module = gc_mlir.ir.Module.parse(mlir_file.read()) - entry = benchgc.mlir.util.get_entry(module) - idx: int = 0 - # FIXME: only support RankTensorType now - for i in entry.type.inputs: - args.append(Arg(idx)) - args[-1].dtype = str(i.element_type) - args[-1].shape = list(i.shape) - args[-1].set_scalar() - idx += 1 - - for o in entry.type.results: - args.append(Arg(idx)) - args[-1].dtype = str(o.element_type) - args[-1].shape = list(o.shape) - args[-1].set_scalar() - idx += 1 -elif flags.driver in ["linalg"]: - # all arg shape/dt should be provided in single op test - for i in range(len(flags.md)): - args.append(Arg(i)) - - for md in flags.md: - colon = md.find(":") - if colon == -1: - raise Exception("Wrong md format: %s", md) - idx = int(md[:colon]) - args[idx].set_md(md[colon + 1 :]) - - from .linalg import mlir_op - - mlir_func = mlir_op[flags.case] - module = mlir_func(flags, args) -else: - raise Exception(f"unsupported driver {flags.driver}") - -for fill in flags.fill: - colon = fill.find(":") - if colon == -1: - raise Exception("Wrong fill format: %s", fill) - idx = int(fill[:colon]) - args[idx].set_fill(fill[colon + 1 :]) - -for cmp in flags.cmp: - colon = cmp.find(":") - if colon == -1: - raise Exception("Wrong cmp format: %s", cmp) - idx = int(cmp[:colon]) - args[idx].set_cmp(cmp[colon + 1 :]) - -entry = benchgc.mlir.util.get_entry(module) - -for i, arg in enumerate(args): - # use zero filling if the arg is return value - set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) - set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) - -for arg in args: - arg.print_verbose(flags.verbose) - -if flags.verbose >= benchgc.util.MODULE_VERBOSE: - print(module) - -ref_args: List[torch.Tensor] = [] -gc_args: List[torch.Tensor | int] = [] -ref_tensors: Dict[str, torch.Tensor] = {} -gc_tensors: Dict[str, torch.Tensor] = {} - -for i in range(len(args)): - tensor = fill_tensor(flags, args[i], i) - gc_tensors["%arg" + str(i)] = tensor - ref_tensors["%arg" + str(i)] = tensor.clone() - ref_args.append(ref_tensors["%arg" + str(i)]) - if args[i].scalar: - gc_args.append(tensor.data_ptr()) +def add_bench_options(parser: argparse.ArgumentParser): + ''' add options for bench mode''' + if parser.parse_known_args()[0].mode == "P": + parser.add_argument("-p", "--print_ir", + action="store_true", + help="if need print the IR after pipeline", + required=False + ) + parser.add_argument( + "--disable_results_to_params", + action="store_true", + default=False + ) + parser.add_argument( + "--bench_kind", + type=str, choices=["py", "wrapper"],default="py" + ) + parser.add_argument("--warm_up", type=int, default=100) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--entry", type=str, default="main_entry") + + +def add_pattern_options(parser: argparse.ArgumentParser): + '''add options for each pattern''' + if parser.parse_known_args()[0].driver == "pattern": + pattern_name = parser.parse_known_args()[0].driver + get_pattern_clz(pattern_name).add_args(parser) + + +def get_pattern_clz(diver_str: str): + """Function getting Pattern class by name.""" + clz = {"mlp": MLP}[diver_str] + return clz + +def get_moudle_and_args(flags): + '''get module and args''' + args: List[Arg] = [] + if flags.driver == "mlir": + # we need to find all args by reading the entry function + with open(flags.case, "r") as mlir_file: + with gc_mlir.ir.Context() as ctx: + module = gc_mlir.ir.Module.parse(mlir_file.read()) + entry = benchgc.mlir.util.get_entry(module) + idx: int = 0 + # FIXME: only support RankTensorType now + for i in entry.type.inputs: + args.append(Arg(idx)) + args[-1].dtype = str(i.element_type) + args[-1].shape = list(i.shape) + args[-1].set_scalar() + idx += 1 + + for o in entry.type.results: + args.append(Arg(idx)) + args[-1].dtype = str(o.element_type) + args[-1].shape = list(o.shape) + args[-1].set_scalar() + idx += 1 + elif flags.driver in ["linalg"]: + # all arg shape/dt should be provided in single op test + for i in range(len(flags.md)): + args.append(Arg(i)) + + for md in flags.md: + colon = md.find(":") + if colon == -1: + raise Exception("Wrong md format: %s", md) + idx = int(md[:colon]) + args[idx].set_md(md[colon + 1 :]) + + from .linalg import mlir_op + + mlir_func = mlir_op[flags.case] + module = mlir_func(flags, args) else: - gc_args.append(tensor) - - -# ref_out contains return value of the entry -ref_out = runner.ref_run(entry, ref_tensors) - -# we need to swap the result into the args if some arg is the return value -if ref_out is not None: - for i in range(len(ref_out)): - ref_args[0 - i - 1] = ref_out[0 - i - 1] + raise Exception(f"unsupported driver {flags.driver}") -entry = "entry" - -mlir_args = get_mlir_args(gc_args) -passes = "any(gc-cpu-pipeline)" - -with module.context: - compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module) - engine.invoke(entry, *mlir_args) + for fill in flags.fill: + colon = fill.find(":") + if colon == -1: + raise Exception("Wrong fill format: %s", fill) + idx = int(fill[:colon]) + args[idx].set_fill(fill[colon + 1 :]) -fail, mistrust = False, False -for i in range(len(args)): - # gc_arg contains address for scalar value - # we need to find result by arg name - res = compare_tensor( - args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose - ) - fail = fail or (not res[0]) - if res[1] is not None: - mistrust = mistrust | res[1] -if fail: - print(f"FAIL: {flags.driver}.{flags.case}") - sys.exit(1) -elif mistrust: - print(f"MISTRUST: {flags.driver}.{flags.case}") -else: - print(f"PASSED: {flags.driver}.{flags.case}") + for cmp in flags.cmp: + colon = cmp.find(":") + if colon == -1: + raise Exception("Wrong cmp format: %s", cmp) + idx = int(cmp[:colon]) + args[idx].set_cmp(cmp[colon + 1 :]) + + entry = benchgc.mlir.util.get_entry(module) + + for i, arg in enumerate(args): + # use zero filling if the arg is return value + set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) + set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) + + for arg in args: + arg.print_verbose(flags.verbose) + + if flags.verbose >= benchgc.util.MODULE_VERBOSE: + print(module) + return module, args + +def correctness_testing(flags, module, args): + ref_args: List[torch.Tensor] = [] + gc_args: List[torch.Tensor | int] = [] + ref_tensors: Dict[str, torch.Tensor] = {} + gc_tensors: Dict[str, torch.Tensor] = {} + + for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + ref_tensors["%arg" + str(i)] = tensor.clone() + ref_args.append(ref_tensors["%arg" + str(i)]) + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + + entry = "entry" + # ref_out contains return value of the entry + ref_out = runner.ref_run(entry, ref_tensors) + + # we need to swap the result into the args if some arg is the return value + if ref_out is not None: + for i in range(len(ref_out)): + ref_args[0 - i - 1] = ref_out[0 - i - 1] + + + + mlir_args = get_mlir_args(gc_args) + passes = "any(gc-cpu-pipeline)" + + with module.context: + compiler = GraphCompiler(passes) + engine = compiler.compile_and_jit(module) + engine.invoke(entry, *mlir_args) + + fail, mistrust = False, False + for i in range(len(args)): + # gc_arg contains address for scalar value + # we need to find result by arg name + res = compare_tensor( + args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose + ) + fail = fail or (not res[0]) + if res[1] is not None: + mistrust = mistrust | res[1] + if fail: + print(f"FAIL: {flags.driver}.{flags.case}") + sys.exit(1) + elif mistrust: + print(f"MISTRUST: {flags.driver}.{flags.case}") + else: + print(f"PASSED: {flags.driver}.{flags.case}") + + +def performance_testing(flags, module, args): + gc_args: List[torch.Tensor | int] = [] + gc_tensors: Dict[str, torch.Tensor] = {} + for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + mlir_args = get_mlir_args(gc_args) + with module.context as ctx: + if flags.print_ir: + ctx.enable_multithreading(False) + bench_kind = py_timeit_bench if flags.bench_kind == "py" else mlir_wrapper_bench + execute_cost, compile_cost = bench_kind( + module, + "entry", + "any(gc-cpu-pipeline)", + mlir_args, + flags.print_ir, + flags.repeat, + flags.warm_up, + ) + print("===========bench result===========") + json_res = json.dumps( + { + "args": vars(flags), + "compile_cost(ms)": compile_cost, + "execute_cost(ms)": execute_cost, + }, + indent=4, + ) + print(json_res) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") + add_common_options(arg_parser) + add_bench_options(arg_parser) + add_pattern_options(arg_parser) + flags = arg_parser.parse_args() + benchgc.util.set_seed(flags.seed) + ir_module, module_args = get_moudle_and_args(flags) + if flags.mode == "C": + correctness_testing(flags, ir_module, module_args) + elif flags.mode == "P": + performance_testing(flags, ir_module, module_args) + else: + pass \ No newline at end of file diff --git a/test/benchgc/src/benchgc/mlir/bench.py b/test/benchgc/src/benchgc/mlir/bench.py new file mode 100644 index 000000000..b005775a4 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/bench.py @@ -0,0 +1,206 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import ctypes +import random +import timeit +from typing import List, Sequence, Tuple + +import numpy as np +from gc_mlir import ir, runtime +from gc_mlir.graph_compiler import GraphCompiler +from benchgc.mlir.util import ( + emit_benchmark_wrapped_main_func, + emit_nano_time, + get_kernel_func_from_module, +) + + +def py_timeit_bench( + ir_module: ir.Module, + entry_name: str, + pipeline: str, + mlir_args: list, + ir_printing=False, + repeat_time=100, + warm_up=20, +) -> Tuple[float, float]: + """benchmark mlir with python timeit.""" + compiler = GraphCompiler(pipeline) + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(ir_module, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + + # Copied from execution_engine.py so that the cost of cast does not affect perf result. + func = engine.lookup(entry_name) + packed_args = (ctypes.c_void_p * len(mlir_args))() + for argNum in range(len(mlir_args)): + packed_args[argNum] = ctypes.cast(mlir_args[argNum], ctypes.c_void_p) + + def run_bench(func, arg): + func(arg) + + timeit.timeit(lambda: run_bench(func, packed_args), number=warm_up) + total_time = timeit.timeit(lambda: run_bench(func, packed_args), number=repeat_time) + execute_cost = total_time * 1000 / repeat_time + return (execute_cost, compile_cost) + + +def mlir_wrapper_bench( + ir_module: ir.Module, + entry_name: str, + pipeline: str, + mlir_args: list, + ir_printing=False, + repeat_time=100, + warm_up=20, +) -> Tuple[float, float]: + """benchmark mlir with a wrapper func.""" + kernel_func = get_kernel_func_from_module(ir_module, entry_name) + wrapper_module = ir_module + with ir.InsertionPoint(wrapper_module.body): + emit_benchmark_wrapped_main_func(kernel_func, emit_nano_time()) + compiler = GraphCompiler(pipeline) + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(wrapper_module, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + + np_timers_ns = np.array([0], dtype=np.int64) + time_arg = ctypes.pointer( + ctypes.pointer(runtime.get_ranked_memref_descriptor(np_timers_ns)) + ) + total_time = 0 + ns_to_ms_scale = 1e-6 + + def run(engine_invoke, bench_func_name, *mlir_args): + engine_invoke(bench_func_name, *mlir_args) + + for i in range(repeat_time + warm_up): + run(engine.invoke, "wrapped_main", time_arg, *mlir_args) + if i >= warm_up: + total_time += int(np_timers_ns[0]) * ns_to_ms_scale + execute_cost = total_time / repeat_time + return (execute_cost, compile_cost) + + +# for test +def fake_bench( + ir_module: ir.Module = None, + entry_name: str = None, + pipeline: str = None, + mlir_args: list = None, + ir_printing=False, + repeat_time=100, + warm_up=20, +) -> Tuple[float, float]: + """genrate fake benchmark result.""" + execute_cost = float(random.randint(1, 100)) + compile_cost = float(random.randint(1, 100)) + return (execute_cost, compile_cost) + + +def batch_py_timeit_bench( + ir_modules: List[ir.Module], + entry_name: str, + pipeline: str, + mlir_args: list, + ir_printing=False, + repeat_time=5, + warm_up=2, +) -> List[Tuple[float, float]]: + """benchmark a batch of mlir with python timeit.""" + compiler = GraphCompiler(pipeline) + funcs = [] + compile_costs = [] + for m in ir_modules: + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(m, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + compile_costs.append(compile_cost) + funcs.append(engine.lookup(entry_name)) + + # Copied from execution_engine.py so that the cost of cast does not affect perf result. + packed_args = (ctypes.c_void_p * len(mlir_args))() + for argNum in range(len(mlir_args)): + packed_args[argNum] = ctypes.cast(mlir_args[argNum], ctypes.c_void_p) + + def run_bench(func, arg): + func(arg) + + for func in funcs: + timeit.timeit(lambda: run_bench(func, packed_args), number=warm_up) + + execute_costs = [] + for func in funcs: + total_time = timeit.timeit( + lambda: run_bench(func, packed_args), number=repeat_time + ) + execute_cost = total_time * 1000 / repeat_time + execute_costs.append(execute_cost) + return list(zip(compile_costs, execute_costs)) + + +def batch_mlir_wrapper_bench( + ir_modules: ir.Module, + entry_name: str, + pipeline: str, + mlir_args: list, + ir_printing=False, + repeat_time=5, + warm_up=2, +) -> Tuple[float, float]: + """benchmark a batch of mlir with wrapper func.""" + compiler = GraphCompiler(pipeline) + + engine_invokes = [] + compile_costs = [] + for m in ir_modules: + kernel_func = get_kernel_func_from_module(m, entry_name) + wrapper_module = m + with ir.InsertionPoint(wrapper_module.body): + emit_benchmark_wrapped_main_func(kernel_func, emit_nano_time()) + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(wrapper_module, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + compile_costs.append(compile_cost) + engine_invokes.append(engine.invoke) + + np_timers_ns = np.array([0], dtype=np.int64) + time_arg = ctypes.pointer( + ctypes.pointer(runtime.get_ranked_memref_descriptor(np_timers_ns)) + ) + total_time = 0 + ns_to_ms_scale = 1e-6 + + def run(engine_invoke, bench_func_name, *mlir_args): + engine_invoke(bench_func_name, *mlir_args) + + for engine_invoke in engine_invokes: + for _ in range(warm_up): + run(engine_invoke, "wrapped_main", time_arg, *mlir_args) + + execute_costs = [] + for engine_invoke in engine_invokes: + total_time = 0 + for _ in range(repeat_time): + run(engine_invoke, "wrapped_main", time_arg, *mlir_args) + total_time += int(np_timers_ns[0]) * ns_to_ms_scale + + execute_cost = total_time / repeat_time + execute_costs.append(execute_cost) + + return list(zip(compile_costs, execute_costs)) \ No newline at end of file diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 24169bca1..2213139b2 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -17,15 +17,15 @@ import ctypes from typing import Any, List -import gc_mlir.ir import torch -from gc_mlir.dialects import func +from gc_mlir.dialects import func, arith, memref +from gc_mlir import ir # only python 3.11 support # from typing import Self -def get_entry(module: gc_mlir.ir.Module, entry: str = '"entry"') -> func.FuncOp: +def get_entry(module: ir.Module, entry: str = '"entry"') -> func.FuncOp: for op in module.operation.opview.regions[0].blocks[0].operations: if str(op.name) == entry: return op @@ -72,40 +72,96 @@ def dtype_to_ctype(dtype: torch.dtype): raise ValueError(f"Unsupported torch dtype: {dtype}") -def str_to_mlir_dtype(ctx: gc_mlir.ir.Context, dtype: str) -> gc_mlir.ir.Type: +def str_to_mlir_dtype(ctx: ir.Context, dtype: str) -> ir.Type: if dtype == "f32": - return gc_mlir.ir.F32Type.get(ctx) + return ir.F32Type.get(ctx) elif dtype == "f64": - return gc_mlir.ir.F64Type.get(ctx) + return ir.F64Type.get(ctx) elif dtype == "f16": - return gc_mlir.ir.F16Type.get(ctx) + return ir.F16Type.get(ctx) elif dtype == "bf16": - return gc_mlir.ir.BF16Type.get(ctx) + return ir.BF16Type.get(ctx) elif dtype == "u8": - return gc_mlir.ir.IntegerType.get_unsigned(8, ctx) + return ir.IntegerType.get_unsigned(8, ctx) elif dtype == "s8": - return gc_mlir.ir.IntegerType.get_signed(8, ctx) + return ir.IntegerType.get_signed(8, ctx) elif dtype == "boolean": - return gc_mlir.ir.IntegerType.get_unsigned(1, ctx) + return ir.IntegerType.get_unsigned(1, ctx) elif dtype == "f8_e4m3": - return gc_mlir.ir.Float8E4M3FNType.get(ctx) + return ir.Float8E4M3FNType.get(ctx) elif dtype == "f8_e5m2": - return gc_mlir.ir.Float8E5M2Type.get(ctx) + return ir.Float8E5M2Type.get(ctx) elif dtype == "s32": - return gc_mlir.ir.IntegerType.get_signed(32, ctx) + return ir.IntegerType.get_signed(32, ctx) else: raise Exception(f"data type not support: {dtype}") def str_to_mlir_typed_attr( - ctx: gc_mlir.ir.Context, dtype: str, value: Any -) -> gc_mlir.ir.Attribute: + ctx: ir.Context, dtype: str, value: Any +) -> ir.Attribute: mlir_dtype = str_to_mlir_dtype(ctx, dtype) if dtype in ["f32", "f64", "bf16", "f16", "f8_e4m3", "f8_e5m2"]: - return gc_mlir.ir.FloatAttr.get(mlir_dtype, value) + return ir.FloatAttr.get(mlir_dtype, value) elif dtype in ["u8", "s8", "s32"]: - return gc_mlir.ir.IntegerAttr.get(mlir_dtype, value) + return ir.IntegerAttr.get(mlir_dtype, value) elif dtype == "boolean": - return gc_mlir.ir.BoolAttr.get(value) + return ir.BoolAttr.get(value) else: raise Exception(f"data type not support: {dtype}") + + +def emit_benchmark_wrapped_main_func( + kernel_func: func.FuncOp, timer_func: func.FuncOp +) -> func.FuncOp: + """Emit a wrapped main function that calls the kernel function and records the time taken.""" + memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) + wrapped_func_name = "wrapped_main" + assert wrapped_func_name != str( + kernel_func.name + ), "wrapped function name should be different from kernel function name" + wrapped_func = func.FuncOp( + wrapped_func_name, + ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), + visibility="public", + ) + wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(wrapped_func.add_entry_block()): + timer_buffer = wrapped_func.arguments[0] + start = func.CallOp(timer_func, []) + call_op = func.CallOp( + kernel_func, + list(wrapped_func.arguments[1:]), + ) + end = func.CallOp(timer_func, []) + time_taken = arith.SubIOp(end, start) + zero = arith.ConstantOp.create_index(0) + memref.StoreOp(time_taken, timer_buffer, [zero]) + func.ReturnOp(call_op.results) + return wrapped_func + + + +def emit_nano_time() -> func.FuncOp: + """Emit a nanoTime function that returns the current time in nanoseconds.""" + nanoTime = func.FuncOp( + "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" + ) + nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + return nanoTime + + +def get_kernel_func_from_module( + module: ir.Module, func_name: str = "main_entry" +) -> func.FuncOp: + """Get the func op by the name from a module""" + assert ( + len(module.operation.regions) == 1 + ), "Expected kernel module to have only one region" + assert ( + len(module.operation.regions[0].blocks) == 1 + ), "Expected kernel module to have only one block" + for f in module.operation.regions[0].blocks[0].operations: + if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: + return f + raise ValueError("can not find the entry function") \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/CMakeLists.txt b/test/benchgc/src/benchgc/pattern/CMakeLists.txt new file mode 100644 index 000000000..51683ac19 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/pattern/ COPYONLY) +endforeach() \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/__init__.py b/test/benchgc/src/benchgc/pattern/__init__.py new file mode 100644 index 000000000..e97242f09 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/__init__.py @@ -0,0 +1,20 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import sys +import pathlib + +sys.path.append(pathlib.Path(__file__).parent.resolve().__str__()) \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py new file mode 100644 index 000000000..d5cf38ce1 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -0,0 +1,320 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import argparse +from abc import ABC, abstractmethod +from typing import List + +import numpy as np +from gc_mlir import ir +from gc_mlir.dialects import arith, func, linalg, tensor +from gc_mlir.ir import BF16Type, FloatAttr +from benchgc.mlir.util import ( + str_to_mlir_dtype, + get_kernel_func_from_module, +) + + +def to_int_list(s: str) -> List[int]: + """ + Parsing the cmd for list of int values + + Args: + s (str): int values in cmd, example: 2x3x4 + + Returns: + List[int]: int values in list, example: [2, 3, 4] + """ + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_list(s: str) -> List[bool]: + """ + Parsing the cmd for list of bool values + + Args: + s (str): bools in cmd, example: 1x0x1 + + Returns: + List[bool]: bools in list, example: [True, False, True] + """ + if not s or len(s) == 0: + return [] + return [bool(int(i)) for i in s.strip().split("x")] + +class Pattern(ABC): + """Abstract class for driver.""" + + @staticmethod + @abstractmethod + def add_args(parser: argparse.ArgumentParser): + """Add arguments to parser""" + + @abstractmethod + def handle_args(self, args: argparse.Namespace): + """Get and handle the args""" + + def __init__(self, ctx: ir.Context, args: argparse.Namespace): + self.main_entry = "main_entry" + self.handle_args(args) + self.ir_module = self.init_module(ctx) + + @abstractmethod + def init_module(self, ctx: ir.Context) -> ir.Module: + """Create MLIR moudule by args""" + +class MLP(Pattern): + @staticmethod + def add_args(parser: argparse.ArgumentParser): + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--hidden_size_list", type=str, default="") + parser.add_argument("--has_bias", required=False, type=str) + parser.add_argument( + "--act_type", type=str, choices=["noop", "relu", "sigmoid"], default="noop" + ) + parser.add_argument( + "--dtype", + type=str, + choices=[ + "f32", + "bf16", + ], + default="f32", + ) + + def handle_args(self, args: argparse.Namespace): + self.batch_size = args.batch_size + assert self.batch_size > 0, "batch size should be greater than 0" + + self.hidden_size_list = to_int_list(args.hidden_size_list) + layers = len(self.hidden_size_list) - 1 + assert layers >= 1, "hidden_size_list should have at least 2 elements" + + self.has_bias = ( + [False] * layers if args.has_bias is None else to_bool_list(args.has_bias) + ) + + assert ( + len(self.has_bias) == layers + ), "has_bias should have the same length as hidden_size_list" + + self.act_type = args.act_type + self.dtype = args.dtype + + def init_module(self, ctx: ir.Context) -> ir.Module: + with ctx, ir.Location.unknown(): + layers = len(self.hidden_size_list) - 1 + module = ir.Module.create() + dtype = str_to_mlir_dtype(self.dtype, ctx) + src = ir.RankedTensorType.get( + [self.batch_size, self.hidden_size_list[0]], dtype + ) + weights = [] + bias = [] + for i in range(layers): + weights.append( + ir.RankedTensorType.get( + [ + self.hidden_size_list[i], + self.hidden_size_list[i + 1], + ], + dtype, + ) + ) + if self.has_bias[i]: + bias.append( + ir.RankedTensorType.get([self.hidden_size_list[i + 1]], dtype) + ) + result = ir.RankedTensorType.get( + [ + self.batch_size, + self.hidden_size_list[-1], + ], + dtype, + ) + with ir.InsertionPoint(module.body): + f = func.FuncOp( + name=self.main_entry, + type=ir.FunctionType.get( + inputs=[src] + weights + bias, results=[result] + ), + ) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + data = f.entry_block.arguments[0] + bias_idx = len(weights) + 1 + for i in range(layers): + weight = f.entry_block.arguments[i + 1] + if self.has_bias[i]: + bias = f.entry_block.arguments[bias_idx] + bias_idx += 1 + else: + bias = None + layer_out_shape = [ + self.batch_size, + self.hidden_size_list[i + 1], + ] + + data = linalg.matmul( + data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + ) + if bias: + broadcast_bias = linalg.broadcast( + bias, + outs=[tensor.EmptyOp(layer_out_shape, dtype)], + dimensions=[0], + ) + data = linalg.add( + data, + broadcast_bias, + outs=[tensor.EmptyOp(layer_out_shape, dtype)], + ) + + if self.act_type == "relu": + element = FloatAttr.get(dtype, 0) + tensor_type = ir.RankedTensorType.get( + layer_out_shape, dtype + ) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + cst = arith.ConstantOp(tensor_type, attr) + data = linalg.max( + data, cst, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + ) + func.ReturnOp([data]) + return module + + + @staticmethod + def add_args(parser: argparse.ArgumentParser): + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--hidden_size_list", type=str, default="") + parser.add_argument("--has_bias", required=False, type=str) + parser.add_argument( + "--act_type", type=str, choices=["noop", "relu", "sigmoid"], default="noop" + ) + parser.add_argument( + "--dtype", + type=str, + choices=[ + "f32", + "bf16", + ], + default="f32", + ) + + def handle_args(self, args: argparse.Namespace): + self.batch_size = args.batch_size + assert self.batch_size > 0, "batch size should be greater than 0" + + self.hidden_size_list = to_int_list(args.hidden_size_list) + layers = len(self.hidden_size_list) - 1 + assert layers >= 1, "hidden_size_list should have at least 2 elements" + + self.has_bias = ( + [False] * layers if args.has_bias is None else to_bool_list(args.has_bias) + ) + + assert ( + len(self.has_bias) == layers + ), "has_bias should have the same length as hidden_size_list" + + self.act_type = args.act_type + self.dtype = args.dtype + + def init_module(self, ctx: ir.Context) -> ir.Module: + with ctx, ir.Location.unknown(): + layers = len(self.hidden_size_list) - 1 + module = ir.Module.create() + dtype = str_to_mlir_dtype(self.dtype, ctx) + src = ir.RankedTensorType.get( + [self.batch_size, self.hidden_size_list[0]], dtype + ) + weights = [] + bias = [] + for i in range(layers): + weights.append( + ir.RankedTensorType.get( + [ + self.hidden_size_list[i], + self.hidden_size_list[i + 1], + ], + dtype, + ) + ) + if self.has_bias[i]: + bias.append( + ir.RankedTensorType.get([self.hidden_size_list[i + 1]], dtype) + ) + result = ir.RankedTensorType.get( + [ + self.batch_size, + self.hidden_size_list[-1], + ], + dtype, + ) + with ir.InsertionPoint(module.body): + f = func.FuncOp( + name=self.main_entry, + type=ir.FunctionType.get( + inputs=[src] + weights + bias, results=[result] + ), + ) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + data = f.entry_block.arguments[0] + bias_idx = len(weights) + 1 + for i in range(layers): + weight = f.entry_block.arguments[i + 1] + if self.has_bias[i]: + bias = f.entry_block.arguments[bias_idx] + bias_idx += 1 + else: + bias = None + layer_out_shape = [ + self.batch_size, + self.hidden_size_list[i + 1], + ] + + data = linalg.matmul( + data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + ) + if bias: + broadcast_bias = linalg.broadcast( + bias, + outs=[tensor.EmptyOp(layer_out_shape, dtype)], + dimensions=[0], + ) + data = linalg.add( + data, + broadcast_bias, + outs=[tensor.EmptyOp(layer_out_shape, dtype)], + ) + + if self.act_type == "relu": + element = FloatAttr.get(dtype, 0) + tensor_type = ir.RankedTensorType.get( + layer_out_shape, dtype + ) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + cst = arith.ConstantOp(tensor_type, attr) + data = linalg.max( + data, cst, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + ) + func.ReturnOp([data]) + return module \ No newline at end of file From 2cccd042df4dbb2f72aa3611728dd9167b6c01e5 Mon Sep 17 00:00:00 2001 From: Jialei A Wang Date: Fri, 16 Aug 2024 01:31:49 +0000 Subject: [PATCH 03/38] introduce benchgc for correctness check --- .github/workflows/build.yml | 12 + .github/workflows/style.yml | 21 + .gitignore | 1 + CMakeLists.txt | 1 + scripts/correctness.sh | 106 +++ test/CMakeLists.txt | 4 + test/benchgc/.gitignore | 5 + test/benchgc/CMakeLists.txt | 41 + test/benchgc/README.md | 257 ++++++ test/benchgc/cases/generic.mlir | 15 + test/benchgc/cases/llama2.mlir | 113 +++ test/benchgc/cases/reduce.mlir | 12 + test/benchgc/setup.py | 30 + test/benchgc/src/benchgc/CMakeLists.txt | 22 + test/benchgc/src/benchgc/__init__.py | 20 + test/benchgc/src/benchgc/__main__.py | 272 ++++++ test/benchgc/src/benchgc/arg/CMakeLists.txt | 22 + test/benchgc/src/benchgc/arg/__init__.py | 163 ++++ test/benchgc/src/benchgc/arg/arg.py | 54 ++ test/benchgc/src/benchgc/arg/binary.py | 98 ++ test/benchgc/src/benchgc/arg/compare.py | 137 +++ test/benchgc/src/benchgc/arg/conv.py | 189 ++++ test/benchgc/src/benchgc/arg/eltwise.py | 177 ++++ test/benchgc/src/benchgc/arg/matmul.py | 173 ++++ test/benchgc/src/benchgc/arg/pool.py | 97 ++ test/benchgc/src/benchgc/arg/reduce.py | 78 ++ test/benchgc/src/benchgc/arg/softmax.py | 93 ++ test/benchgc/src/benchgc/arith/CMakeLists.txt | 22 + test/benchgc/src/benchgc/arith/__init__.py | 45 + test/benchgc/src/benchgc/arith/basic.py | 58 ++ .../benchgc/src/benchgc/linalg/CMakeLists.txt | 22 + test/benchgc/src/benchgc/linalg/__init__.py | 52 ++ test/benchgc/src/benchgc/linalg/binary.py | 137 +++ test/benchgc/src/benchgc/linalg/conv.py | 834 ++++++++++++++++++ test/benchgc/src/benchgc/linalg/eltwise.py | 197 +++++ test/benchgc/src/benchgc/linalg/generic.py | 236 +++++ test/benchgc/src/benchgc/linalg/matmul.py | 317 +++++++ test/benchgc/src/benchgc/linalg/misc.py | 97 ++ test/benchgc/src/benchgc/linalg/pool.py | 489 ++++++++++ test/benchgc/src/benchgc/linalg/softmax.py | 47 + test/benchgc/src/benchgc/mlir/CMakeLists.txt | 22 + test/benchgc/src/benchgc/mlir/__init__.py | 15 + test/benchgc/src/benchgc/mlir/arg.py | 174 ++++ test/benchgc/src/benchgc/mlir/module.py | 48 + test/benchgc/src/benchgc/mlir/util.py | 111 +++ test/benchgc/src/benchgc/runner.py | 109 +++ .../benchgc/src/benchgc/tensor/CMakeLists.txt | 22 + test/benchgc/src/benchgc/tensor/__init__.py | 45 + test/benchgc/src/benchgc/tensor/basic.py | 33 + test/benchgc/src/benchgc/tensor/shape.py | 59 ++ test/benchgc/src/benchgc/util.py | 334 +++++++ 51 files changed, 5738 insertions(+) create mode 100755 scripts/correctness.sh create mode 100644 test/benchgc/.gitignore create mode 100644 test/benchgc/CMakeLists.txt create mode 100644 test/benchgc/README.md create mode 100644 test/benchgc/cases/generic.mlir create mode 100644 test/benchgc/cases/llama2.mlir create mode 100644 test/benchgc/cases/reduce.mlir create mode 100644 test/benchgc/setup.py create mode 100644 test/benchgc/src/benchgc/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/__init__.py create mode 100644 test/benchgc/src/benchgc/__main__.py create mode 100644 test/benchgc/src/benchgc/arg/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/arg/__init__.py create mode 100644 test/benchgc/src/benchgc/arg/arg.py create mode 100644 test/benchgc/src/benchgc/arg/binary.py create mode 100644 test/benchgc/src/benchgc/arg/compare.py create mode 100644 test/benchgc/src/benchgc/arg/conv.py create mode 100644 test/benchgc/src/benchgc/arg/eltwise.py create mode 100644 test/benchgc/src/benchgc/arg/matmul.py create mode 100644 test/benchgc/src/benchgc/arg/pool.py create mode 100644 test/benchgc/src/benchgc/arg/reduce.py create mode 100644 test/benchgc/src/benchgc/arg/softmax.py create mode 100644 test/benchgc/src/benchgc/arith/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/arith/__init__.py create mode 100644 test/benchgc/src/benchgc/arith/basic.py create mode 100644 test/benchgc/src/benchgc/linalg/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/linalg/__init__.py create mode 100644 test/benchgc/src/benchgc/linalg/binary.py create mode 100644 test/benchgc/src/benchgc/linalg/conv.py create mode 100644 test/benchgc/src/benchgc/linalg/eltwise.py create mode 100644 test/benchgc/src/benchgc/linalg/generic.py create mode 100644 test/benchgc/src/benchgc/linalg/matmul.py create mode 100644 test/benchgc/src/benchgc/linalg/misc.py create mode 100644 test/benchgc/src/benchgc/linalg/pool.py create mode 100644 test/benchgc/src/benchgc/linalg/softmax.py create mode 100644 test/benchgc/src/benchgc/mlir/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/mlir/__init__.py create mode 100644 test/benchgc/src/benchgc/mlir/arg.py create mode 100644 test/benchgc/src/benchgc/mlir/module.py create mode 100644 test/benchgc/src/benchgc/mlir/util.py create mode 100644 test/benchgc/src/benchgc/runner.py create mode 100644 test/benchgc/src/benchgc/tensor/CMakeLists.txt create mode 100644 test/benchgc/src/benchgc/tensor/__init__.py create mode 100644 test/benchgc/src/benchgc/tensor/basic.py create mode 100644 test/benchgc/src/benchgc/tensor/shape.py create mode 100644 test/benchgc/src/benchgc/util.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 102d3906a..0ca233a76 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,3 +45,15 @@ jobs: - name: Test run: | cmake --build build --target gc-check + + - name: Build and install benchgc + working-directory: build + run: | + ninja benchgc + pip uninstall -y benchgc || true + pip install test/benchgc/dist/benchgc-*.whl + - name: Correctness Test + env: + LD_PRELOAD: /lib/x86_64-linux-gnu/libomp5.so + run: | + scripts/correctness.sh \ No newline at end of file diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 4c66e1f14..91efdebf6 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -22,3 +22,24 @@ jobs: run: | clang-format --version find . -name *.cpp -or -name *.hpp | xargs clang-format --dry-run --Werror -style=file + + python_format: + runs-on: ubuntu-latest + steps: + - name: checkout base version + uses: actions/checkout@v4 + with: + fetch-depth: 100 + ref: ${{ github.event.pull_request.base.sha }} + + - name: checkout head version + uses: actions/checkout@v4 + with: + fetch-depth: 100 + ref: ${{ github.event.pull_request.head.sha }} + + - name: install darker + run: "python3 -m pip install darker darker[isort] darker[flynt]" + + - name: check python format + run: "python3 -m darker --check -i -f --diff -r `git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}`...HEAD ." \ No newline at end of file diff --git a/.gitignore b/.gitignore index e1fe789da..40c724cd9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ build/ externals/ compile_commands.json +__pycache__ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 90a89666f..d9dcd7313 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ option(GC_ENABLE_IMEX "Enable Intel® Extension for MLIR" OFF) option(GC_ENABLE_BINDINGS_PYTHON "Enable Graph Complier Python Binding" ON) option(GC_DEV_LINK_LLVM_DYLIB "Link dynamic libraries of LLVM and MLIR. For developers only. Do not use it in packing the library." OFF) option(GC_ENABLE_RUNTIME_NAIVE_BRGEMM "Use naive BRGEMM as runtime backend for debug purpose." OFF) +option(GC_BENCH_ENABLE "Build benchgc." ON) if(GC_ENABLE_LEGACY) add_subdirectory(legacy/core) diff --git a/scripts/correctness.sh b/scripts/correctness.sh new file mode 100755 index 000000000..c0ae008ce --- /dev/null +++ b/scripts/correctness.sh @@ -0,0 +1,106 @@ +#! /bin/bash + +export CASE_DIR=$(pwd)/test/benchgc/cases + +FAIL=0 +set -e + +# bf16 +python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:32x128xbf16 --md 1:128x64xbf16 --md 2:32x64xbf16 --cast cast_signed || FAIL=1 + +# f32 + +# misc +python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case broadcast --md 0:1024xf32 --md 1:2x32x1024xf32 --dimensions=0 --dimensions=1 || FAIL=1 + +# matmul +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:16x512x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul_transpose_a --md 0:16x512x64xf32 --md 1:16x512x32xf32 --md 2:16x64x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul_transpose_b --md 0:16x512x64xf32 --md 1:16x128x64xf32 --md 2:16x512x128xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_matvec --md 0:16x512x64xf32 --md 1:16x64xf32 --md 2:16x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_mmt4d --md 0:4x4x8x4x2xf32 --md 1:4x8x8x4x2xf32 --md 2:4x4x8x4x4xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_reduce_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:512x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case batch_vecmat --md 0:16x64xf32 --md 1:16x64x512xf32 --md 2:16x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case dot --md 0:4096xf32 --md 1:4096xf32 --md 2:0xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:1024x512xf32 --md 1:512x512xf32 --md 2:1024x512xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul_transpose_a --md 0:1024x512xf32 --md 1:1024x512xf32 --md 2:512x512xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matmul_transpose_b --md 0:1024x512xf32 --md 1:1024x512xf32 --md 2:1024x1024xf32 --cast cast_signed || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case matvec --md 0:512x64xf32 --md 1:64xf32 --md 2:512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case mmt4d --md 0:4x8x4x2xf32 --md 1:8x8x4x2xf32 --md 2:4x8x4x4xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case vecmat --md 0:512xf32 --md 1:512x64xf32 --md 2:64xf32 || FAIL=1 + +# binary +python3 -m benchgc --verbose 0 --driver linalg --case add --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case sub --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case mul --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case div --md 0:1x32x4096xf32 --md 1:1x32x4096xf32 --md 2:1x32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case max --md 0:1024x1024xf32 --md 1:1024x1024xf32 --md 2:1024x1024xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case min --md 0:1024x1024xf32 --md 1:1024x1024xf32 --md 2:1024x1024xf32 || FAIL=1 + +# element wise +python3 -m benchgc --verbose 0 --driver linalg --case abs --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case ceil --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case erf --md 0:1024x512xf32 --md 1:1024x512xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case floor --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case log --md 0:4096x32xf32 --md 1:4096x32xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case negf --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case exp --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case round --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +# python3 -m benchgc --verbose 0 --driver linalg --case rsqrt --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case sqrt --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case square --md 0:32x4096xf32 --md 1:32x4096xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case tanh --md 0:128x128xf32 --md 1:128x128xf32 || FAIL=1 + +# conv +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d_ncw_fcw --md 0:4x4x32xf32 --md 1:8x4x4xf32 --md 2:4x8x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d_nwc_wcf --md 0:4x32x4xf32 --md 1:4x4x8xf32 --md 2:4x13x8xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_1d --md 0:32xf32 --md 1:4xf32 --md 2:29xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nchw_fchw --md 0:4x4x32x32xf32 --md 1:8x4x4x4xf32 --md 2:4x8x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_ngchw_fgchw --md 0:4x2x2x32x32xf32 --md 1:4x2x2x4x4xf32 --md 2:4x2x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_ngchw_gfchw --md 0:4x2x2x32x32xf32 --md 1:2x4x2x4x4xf32 --md 2:4x2x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nhwc_fhwc --md 0:4x32x32x4xf32 --md 1:8x4x4x4xf32 --md 2:4x13x13x8xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d_nhwc_hwcf --md 0:4x32x32x4xf32 --md 1:4x4x4x8xf32 --md 2:4x13x13x8xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_2d --md 0:32x32xf32 --md 1:4x4xf32 --md 2:29x29xf32 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d_ncdhw_fcdhw --md 0:4x4x32x32x32xf32 --md 1:8x4x4x4x4xf32 --md 2:4x8x13x13x13xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d_ndhwc_dhwcf --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4x8xf32 --md 2:4x13x13x13x8xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case conv_3d --md 0:32x32x32xf32 --md 1:4x4x4xf32 --md 2:29x29x29xf32 || FAIL=1 + +# depthwise conv +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_ncw_cw --md 0:4x4x32xf32 --md 1:4x4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_nwc_wc --md 0:4x32x4xf32 --md 1:4x4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_1d_nwc_wcm --md 0:4x32x4xf32 --md 1:4x4x3xf32 --md 2:4x13x4x3xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nchw_chw --md 0:4x4x32x32xf32 --md 1:4x4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nhwc_hwc --md 0:4x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_2d_nhwc_hwcm --md 0:4x32x32x4xf32 --md 1:4x4x4x3xf32 --md 2:4x13x13x4x3xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ncdhw_cdhw --md 0:4x4x32x32x32xf32 --md 1:4x4x4x4xf32 --md 2:4x4x13x13x13xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ndhwc_dhwc --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case depthwise_conv_3d_ndhwc_dhwcm --md 0:4x32x32x32x4xf32 --md 1:4x4x4x4x3xf32 --md 2:4x13x13x13x4x3xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 + +# pool +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nchw_max --md 0:4x4x32x32xf32 --md 1:4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nchw_sum --md 0:4x4x32x32xf32 --md 1:4x4xf32 --md 2:4x4x13x13xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ncw_max --md 0:4x4x32xf32 --md 1:4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ncw_sum --md 0:4x4x32xf32 --md 1:4xf32 --md 2:4x4x13xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ndhwc_max --md 0:4x32x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_ndhwc_sum --md 0:4x32x32x32x4xf32 --md 1:4x4x4xf32 --md 2:4x13x13x13x4xf32 --strides 2 --strides 2 --strides 2 --dilations 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_max --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_sum --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nhwc_min --md 0:4x32x32x4xf32 --md 1:4x4xf32 --md 2:4x13x13x4xf32 --strides 2 --strides 2 --dilations 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_max --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_sum --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 +python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_min --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1 + +# generic / reduce +python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/generic.mlir || FAIL=1 +python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || FAIL=1 + +# softmax +# python3 -m benchgc --verbose 0 --driver linalg --case softmax --md 0:32x4096xf32 --md 1:32x4096xf32 --dimension 1 || FAIL=1 + +# mlir +# python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/llama2.mlir || FAIL=1 + +set +e +exit $FAIL \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4e4036b19..4baaa28de 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,3 +6,7 @@ endif () include(gtest) add_subdirectory(dnnl) add_subdirectory(mlir) + +if(GC_BENCH_ENABLE) + add_subdirectory(benchgc) +endif() \ No newline at end of file diff --git a/test/benchgc/.gitignore b/test/benchgc/.gitignore new file mode 100644 index 000000000..0fcd2be1e --- /dev/null +++ b/test/benchgc/.gitignore @@ -0,0 +1,5 @@ +dist/ +src/benchgc.egg-info/ +build +benchgc.egg-info/ +__pycache__ diff --git a/test/benchgc/CMakeLists.txt b/test/benchgc/CMakeLists.txt new file mode 100644 index 000000000..e50f35cf2 --- /dev/null +++ b/test/benchgc/CMakeLists.txt @@ -0,0 +1,41 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +if(NOT GC_BENCH_ENABLE) + message(STATUS "Benchgc is not enabled") + return() +endif() + +configure_file(setup.py ${CMAKE_BINARY_DIR}/test/benchgc/setup.py COPYONLY) + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR} COPYONLY) +endforeach() + +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter NumPy REQUIRED) +add_custom_target(benchgc + COMMAND ${Python_EXECUTABLE} setup.py bdist_wheel + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/test/benchgc/" + DEPENDS GcPythonModules) + +add_subdirectory("src/benchgc") +add_subdirectory("src/benchgc/arg") +add_subdirectory("src/benchgc/mlir") +add_subdirectory("src/benchgc/linalg") +add_subdirectory("src/benchgc/tensor") +add_subdirectory("src/benchgc/arith") diff --git a/test/benchgc/README.md b/test/benchgc/README.md new file mode 100644 index 000000000..77499c5fd --- /dev/null +++ b/test/benchgc/README.md @@ -0,0 +1,257 @@ +# benchgc - benchmark tool for graph compiler + +## Description + +Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files based on the OneDNN graph dialect as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. + +## Prerequisite +* python >= 3.10 +* torch >= 2.2 +* pybind11 + +## Build and install +``` +# Please execute at the top level of the project + +mkdir -p build +cd build + +cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON +make -j benchgc + +python -m pip install test/benchgc/dist/benchgc-*.whl + +``` + +## Synopsis +``` +python -m benchgc [OPTIONS] --driver [DRIVER] --case [CASE] +``` +## Flags +### --driver [str] +* linalg: test the single op in linalg dialect +* mlir: upload a mlir file and run +* pattern: predefined pattern test such as mlp + +### --case [str] +* if driver=mlir, please provide a mlir file here to test +* if driver=pattern, please provide the pre-defined pattern name, such as mlp here +* if driver is a dialect name, please provide the detail op name to start a single op test + +### --seed [int] +* set the seed to generate the test data and reprodce the test + +### --verbose [int] +* set the verbose level + +### --md index:SHAPExTYPE +* Describe the shape and data type for argument +* Not available when driver=mlir +* index means the order of argument, including both inputs and outs +* use prefix `0x` (e.g. `0xbf16`) to represent 0d memref or tensor input +* use data type directly (e.g.`f32`) to represent a normal scalar + +``` +# %arg0 -> index = 0 +# tensor<2x2x2xf32> -> index = 1 + +module { + func.func @entry(%arg0: f32) -> tensor<2x2x2xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<2x2x2xf32> + %1 = linalg.fill ins(%arg0 : f32) outs(%0 : tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + return %1 : tensor<2x2x2xf32> + } +} +``` + +### --fill index:fill_type:[:fill_parameter]* +* If not set, benchgc will assign a default method for the argument + +| description | fill_type | fill_parameter | +|-------------|-----------|-----------| +| Zero | Z | | +| Normal | N | mean, std | +| Poisson | P | lambda | +| Binomial | B | n, p | +| Uniform | U | a, b | +| Integer | I | a, b | +| Pytorch tensor dump | F | dump filename | +| Benchdnn driver | D | driver_name[:driver filling parameter]* | + +#### Benchdnn driver filling + +| driver_name | driver filling parameter | +|-------------|--------------------------| +| binary | src0/src1:src0 dtype:src1 dtype:dst dtype | +| conv | src/wei:src dtype:wei dtype:dst dtype:amplifier | +| eltwise | algorithm: alpha: beta (please check https://oneapi-src.github.io/oneDNN/dev_guide_eltwise.html) | +| matmul | src/wei:src dtype:wei dtype:dst dtype:amplifier | +| pool | not required | + +### --cmp index:cmp_type:[:cmp_parameter]* +* If not set, benchgc will assign a default method for the argument + +| description | cmp_type | cmp_parameter | +|-------------|-----------|-----------| +| P2P check | P | threshold, zero_percent(mistrust check) | +| Norm check | N | threshold | +| Benchdnn driver | D | driver_name:dtype:case | + +## Example +``` +# single add op test +# using the same data filling / compare strategy as the benchdnn primitive driver if not set +python3 -m benchgc --verbose 6 --driver linalg --case add --md 0:4x5xf32 --md 1:4x5xf32 --md 2:4x5xf32 + +arg0 shape: [4, 5] dtype: f32 fill_type: D fill_param: ['binary', 'src0', 'f32', 'f32', 'f32'] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +arg1 shape: [4, 5] dtype: f32 fill_type: D fill_param: ['binary', 'src1', 'f32', 'f32', 'f32'] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +arg2 shape: [4, 5] dtype: f32 fill_type: Z fill_param: [] cmp_type: D cmp_param: ['binary', 'f32', 'add'] +module { + func.func @entry(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<4x5xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x5xf32>) -> tensor<4x5xf32> + %2 = linalg.add ins(%arg0, %arg1 : tensor<4x5xf32>, tensor<4x5xf32>) outs(%1 : tensor<4x5xf32>) -> tensor<4x5xf32> + return %2 : tensor<4x5xf32> + } +} + +fill arg0: +tensor([[ -5.0000, 10.0000, 3.7500, -2.5000, -8.7500], + [ 6.2500, 0.0000, -6.2500, 8.7500, 2.5000], + [ -3.7500, -10.0000, 5.0000, -1.2500, -7.5000], + [ 7.5000, 1.2500, -5.0000, 10.0000, 3.7500]]) +fill arg1: +tensor([[ 1.2500, -5.0000, 10.0000, 3.7500, -2.5000], + [ -8.7500, 6.2500, 1.0000, -6.2500, 8.7500], + [ 2.5000, -3.7500, -10.0000, 5.0000, -1.2500], + [ -7.5000, 7.5000, 1.2500, -5.0000, 10.0000]]) +fill arg2: +tensor([[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) +p2p check: threshold: 0.0000001 + (0, 0): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: -6.2500000 res: -6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -10.0000000 res: -10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -7.5000000 res: -7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: 7.5000000 res: 7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000001 + (0, 0): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 1.0000000 res: 1.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -6.2500000 res: -6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: -10.0000000 res: -10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: -7.5000000 res: -7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 7.5000000 res: 7.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 10.0000000 res: 10.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000001 + (0, 0): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 13.7500000 res: 13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 1.2500000 res: 1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -11.2500000 res: -11.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -2.5000000 res: -2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 6.2500000 res: 6.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: -5.2500000 res: -5.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: 2.5000000 res: 2.5000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 11.2500000 res: 11.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 0): ref: -1.2500000 res: -1.2500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 1): ref: -13.7500000 res: -13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 2): ref: -5.0000000 res: -5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 3): ref: 3.7500000 res: 3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (2, 4): ref: -8.7500000 res: -8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 0): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 1): ref: 8.7500000 res: 8.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 2): ref: -3.7500000 res: -3.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 3): ref: 5.0000000 res: 5.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (3, 4): ref: 13.7500000 res: 13.7500000 abs_diff: 0.0000000 rel_diff: 0.0000000 +PASSED: linalg.add +``` + +``` +# set the arg0 filling follows a distribution N(0, 5) +# set the arg1 filling follows a uniform integer filling [-3, 3] +# use P2P compare strategy on arg2 with threshold = 0 & mistrust rate = 100.0% +# zero threshold will fail the case here + +python3 -m benchgc --verbose 6 --driver linalg --case matmul_transpose_b --md 0:2x5xf32 --md 1:2x5xf32 --md 2:2x2xf32 --fill 0:N:0:5 --fill 1:I:-3:3 --cmp 2:P:0:100 +arg0 shape: [2, 5] dtype: f32 fill_type: N fill_param: ['0', '5'] cmp_type: D cmp_param: ['matmul', 'f32', 'matmul_transpose_b'] +arg1 shape: [2, 5] dtype: f32 fill_type: I fill_param: ['-3', '3'] cmp_type: D cmp_param: ['matmul', 'f32', 'matmul_transpose_b'] +arg2 shape: [2, 2] dtype: f32 fill_type: Z fill_param: [] cmp_type: P cmp_param: ['0', '100'] +module { + func.func @entry(%arg0: tensor<2x5xf32>, %arg1: tensor<2x5xf32>) -> tensor<2x2xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2x2xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = linalg.matmul_transpose_b {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<2x5xf32>, tensor<2x5xf32>) outs(%1 : tensor<2x2xf32>) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> + } +} + +fill arg0: +tensor([[ 7.7050, -1.4671, -10.8939, 2.8422, -5.4226], + [ -6.9930, 2.0167, 4.1901, -3.5963, -2.0167]]) +fill arg1: +tensor([[-3., 0., 1., 0., 0.], + [ 3., -3., 2., -3., 0.]]) +fill arg2: +tensor([[0., 0.], + [0., 0.]]) +p2p check: threshold: 0.0000010 + (0, 0): ref: 7.7049804 res: 7.7049804 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -1.4671445 res: -1.4671445 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: -10.8939466 res: -10.8939466 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 2.8421564 res: 2.8421564 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: -5.4226117 res: -5.4226117 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: -6.9929771 res: -6.9929771 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: 2.0167341 res: 2.0167341 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 4.1901317 res: 4.1901317 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -3.5962880 res: -3.5962880 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: -2.0167177 res: -2.0167177 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000010 + (0, 0): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 2): ref: 1.0000000 res: 1.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 3): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 4): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 0): ref: 3.0000000 res: 3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 2): ref: 2.0000000 res: 2.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 3): ref: -3.0000000 res: -3.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 4): ref: 0.0000000 res: 0.0000000 abs_diff: 0.0000000 rel_diff: 0.0000000 +p2p check: threshold: 0.0000000 + (0, 0): ref: -34.0088882 res: -34.0088882 abs_diff: 0.0000000 rel_diff: 0.0000000 + (0, 1): ref: -2.7979884 res: -2.7979879 abs_diff: 0.0000005 rel_diff: 0.0000002 + (1, 0): ref: 25.1690636 res: 25.1690636 abs_diff: 0.0000000 rel_diff: 0.0000000 + (1, 1): ref: -7.8600063 res: -7.8600044 abs_diff: 0.0000019 rel_diff: 0.0000002 +FAIL: linalg.matmul_transpose_b +``` \ No newline at end of file diff --git a/test/benchgc/cases/generic.mlir b/test/benchgc/cases/generic.mlir new file mode 100644 index 000000000..3da555777 --- /dev/null +++ b/test/benchgc/cases/generic.mlir @@ -0,0 +1,15 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func @entry(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x3xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<3x3xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x2xf32>, tensor<2x3xf32>) outs(%0 : tensor<3x3xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<3x3xf32> + return %1 : tensor<3x3xf32> + } +} \ No newline at end of file diff --git a/test/benchgc/cases/llama2.mlir b/test/benchgc/cases/llama2.mlir new file mode 100644 index 000000000..1f557b3e6 --- /dev/null +++ b/test/benchgc/cases/llama2.mlir @@ -0,0 +1,113 @@ +module { + func.func @entry(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) attributes {llvm.emit_c_interface} { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_0 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %67 = arith.addf %in, %init : f32 + linalg.yield %67 : f32 + } + %cst_2 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_2 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_3 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_3, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_4 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf ins(%15, %cst_4 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_5 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_6 = linalg.broadcast ins(%collapsed_5 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_6 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_7 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_7, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_8 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_9 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_9 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_8, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_10 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %cst_11 = arith.constant dense<1.000000e+00> : tensor<1x32x11008xbf16> + %30 = linalg.negf ins(%expanded_10 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = linalg.exp ins(%30 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %32 = linalg.add ins(%cst_11, %31 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %33 = linalg.div ins(%cst_11, %32 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %34 = tensor.empty() : tensor<1x32x11008xbf16> + %35 = linalg.mul ins(%33, %expanded_10 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%34 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_12 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_13 = arith.constant 0.000000e+00 : bf16 + %36 = tensor.empty() : tensor<32x11008xbf16> + %37 = linalg.fill ins(%cst_13 : bf16) outs(%36 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %38 = linalg.matmul_transpose_b ins(%collapsed_12, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%37 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_14 = tensor.expand_shape %38 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %39 = tensor.empty() : tensor<1x32x11008xbf16> + %40 = linalg.mul ins(%35, %expanded_14 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%39 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_15 = tensor.collapse_shape %40 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_16 = arith.constant 0.000000e+00 : bf16 + %41 = tensor.empty() : tensor<32x4096xbf16> + %42 = linalg.fill ins(%cst_16 : bf16) outs(%41 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %43 = linalg.matmul_transpose_b ins(%collapsed_15, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%42 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_17 = tensor.expand_shape %43 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %44 = tensor.empty() : tensor<1x32x4096xbf16> + %45 = linalg.add ins(%4, %expanded_17 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%44 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %46 = tensor.empty() : tensor<1x32x4096xf32> + %47 = linalg.copy ins(%45 : tensor<1x32x4096xbf16>) outs(%46 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %48 = tensor.empty() : tensor<1x32x4096xf32> + %49 = linalg.powf ins(%47, %cst_18 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%48 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_19 = arith.constant 0.000000e+00 : f32 + %50 = tensor.empty() : tensor<1x32xf32> + %51 = linalg.fill ins(%cst_19 : f32) outs(%50 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_20 = linalg.reduce ins(%49 : tensor<1x32x4096xf32>) outs(%51 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %67 = arith.addf %in, %init : f32 + linalg.yield %67 : f32 + } + %cst_21 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %52 = tensor.empty() : tensor<1x32xf32> + %53 = linalg.div ins(%reduced_20, %cst_21 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%52 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_22 = tensor.expand_shape %53 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %54 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_23 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%54 : tensor<1x32x1xf32>) dimensions = [0, 1] + %55 = tensor.empty() : tensor<1x32x1xf32> + %56 = linalg.add ins(%expanded_22, %broadcasted_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%55 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_24 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %57 = tensor.empty() : tensor<1x32x1xf32> + %58 = linalg.powf ins(%56, %cst_24 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%57 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_25 = tensor.collapse_shape %58 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %59 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_26 = linalg.broadcast ins(%collapsed_25 : tensor<1x32xf32>) outs(%59 : tensor<1x32x4096xf32>) dimensions = [2] + %60 = tensor.empty() : tensor<1x32x4096xf32> + %61 = linalg.mul ins(%47, %broadcasted_26 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%60 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %62 = tensor.empty() : tensor<1x32x4096xbf16> + %63 = linalg.copy ins(%61 : tensor<1x32x4096xf32>) outs(%62 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %64 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_27 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%64 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %65 = tensor.empty() : tensor<1x32x4096xbf16> + %66 = linalg.mul ins(%broadcasted_27, %63 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%65 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %66, %45 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> + } +} \ No newline at end of file diff --git a/test/benchgc/cases/reduce.mlir b/test/benchgc/cases/reduce.mlir new file mode 100644 index 000000000..8183319a9 --- /dev/null +++ b/test/benchgc/cases/reduce.mlir @@ -0,0 +1,12 @@ +module { + func.func @entry(%arg0: tensor<3x5xf32>) -> tensor<3xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<3xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<3xf32>) -> tensor<3xf32> + %reduce = linalg.reduce { arith.addf } + ins(%arg0:tensor<3x5xf32>) + outs(%1:tensor<3xf32>) + dimensions = [1] + return %reduce : tensor<3xf32> + } +} \ No newline at end of file diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py new file mode 100644 index 000000000..3d67af539 --- /dev/null +++ b/test/benchgc/setup.py @@ -0,0 +1,30 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import setuptools + +setuptools.setup( + name="benchgc", + description="benchmark tool for graph compiler", + package_dir={ + "benchgc": "src/benchgc", + "gc_mlir": "../../python_packages/gc_mlir_core/gc_mlir", + }, + packages=setuptools.find_packages("src") + + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), + package_data={"gc_mlir": ["_mlir_libs/*.so"]}, + install_requires=["torch", "numpy", "ml_dtypes"], +) diff --git a/test/benchgc/src/benchgc/CMakeLists.txt b/test/benchgc/src/benchgc/CMakeLists.txt new file mode 100644 index 000000000..5700c4f36 --- /dev/null +++ b/test/benchgc/src/benchgc/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/__init__.py b/test/benchgc/src/benchgc/__init__.py new file mode 100644 index 000000000..3b87b7051 --- /dev/null +++ b/test/benchgc/src/benchgc/__init__.py @@ -0,0 +1,20 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import pathlib +import sys + +sys.path.append(pathlib.Path(__file__).parent.resolve().__str__()) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py new file mode 100644 index 000000000..1470183a6 --- /dev/null +++ b/test/benchgc/src/benchgc/__main__.py @@ -0,0 +1,272 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + + +import argparse +import sys +from typing import Dict, List + +import benchgc.mlir.util +import benchgc.util +import gc_mlir.ir +import runner +import torch +from benchgc.arg import ( + compare_tensor, + fill_tensor, + set_default_compare, + set_default_fill, +) +from benchgc.arg.arg import Arg +from benchgc.mlir.arg import get_mlir_args +from gc_mlir.graph_compiler import GraphCompiler + +try: + parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") + parser.add_argument( + "--driver", + required=False, + help="specify the test driver", + choices=["linalg", "tensor", "mlir", "pattern"], + type=str, + ) + parser.add_argument( + "--case", + required=False, + help="test which operation in the specified driver", + type=str, + ) + + parser.add_argument( + "--md", + required=False, + help="format: #ARG:SHAPExTYPE", + type=str, + default=[], + action="append", + ) + parser.add_argument( + "--fill", + required=False, + help="format: #ARG:type:parameter", + type=str, + default=[], + action="append", + ) + parser.add_argument( + "--cmp", + required=False, + help="format: #ARG:type:parameter", + type=str, + default=[], + action="append", + ) + + parser.add_argument( + "--seed", + required=False, + default=0, + type=int, + help="a seed value to generate data filling", + ) + parser.add_argument( + "--verbose", + type=int, + default=benchgc.util.NO_VERBOSE, + help="verbose level", + choices=[ + benchgc.util.NO_VERBOSE, + benchgc.util.MODULE_VERBOSE, + benchgc.util.ARG_VERBOSE, + benchgc.util.COMPARE_VERBOSE, + benchgc.util.ERROR_OUTPUT_VERBOSE, + benchgc.util.OUTPUT_VERBOSE, + benchgc.util.INPUT_VERBOSE, + ], + ) + parser.add_argument( + "--cast", + required=False, + default="cast_signed", + help="define attribute supported by linalg op such as matmul_transpose_b", + choices=["cast_signed", "cast_unsigned"], + type=str, + ) + + # single dimension index + # linalg.softmax + parser.add_argument( + "--dimension", + required=False, + default=None, + help="define the dimension attribute in linalg op", + type=int, + ) + + # multiple dimensions array + # linalg.broadcast / linalg.reduce + parser.add_argument( + "--dimensions", + required=False, + default=None, + action="append", + help="define the dimensions attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--dilations", + required=False, + default=None, + action="append", + help="define the dilations attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--strides", + required=False, + default=None, + action="append", + help="define the strides attribute in linalg op", + type=int, + ) + flags = parser.parse_args() + benchgc.util.set_seed(flags.seed) + + +except argparse.ArgumentError: + sys.stderr.write("Argument parse failed\n") + sys.exit(1) + +args: List[Arg] = [] + +if flags.driver == "mlir": + # we need to find all args by reading the entry function + with open(flags.case, "r") as mlir_file: + with gc_mlir.ir.Context() as ctx: + module = gc_mlir.ir.Module.parse(mlir_file.read()) + entry = benchgc.mlir.util.get_entry(module) + idx: int = 0 + # FIXME: only support RankTensorType now + for i in entry.type.inputs: + args.append(Arg(idx)) + args[-1].dtype = str(i.element_type) + args[-1].shape = list(i.shape) + args[-1].set_scalar() + idx += 1 + + for o in entry.type.results: + args.append(Arg(idx)) + args[-1].dtype = str(o.element_type) + args[-1].shape = list(o.shape) + args[-1].set_scalar() + idx += 1 +elif flags.driver in ["linalg"]: + # all arg shape/dt should be provided in single op test + for i in range(len(flags.md)): + args.append(Arg(i)) + + for md in flags.md: + colon = md.find(":") + if colon == -1: + raise Exception("Wrong md format: %s", md) + idx = int(md[:colon]) + args[idx].set_md(md[colon + 1 :]) + + from .linalg import mlir_op + + mlir_func = mlir_op[flags.case] + module = mlir_func(flags, args) +else: + raise Exception(f"unsupported driver {flags.driver}") + +for fill in flags.fill: + colon = fill.find(":") + if colon == -1: + raise Exception("Wrong fill format: %s", fill) + idx = int(fill[:colon]) + args[idx].set_fill(fill[colon + 1 :]) + +for cmp in flags.cmp: + colon = cmp.find(":") + if colon == -1: + raise Exception("Wrong cmp format: %s", cmp) + idx = int(cmp[:colon]) + args[idx].set_cmp(cmp[colon + 1 :]) + +entry = benchgc.mlir.util.get_entry(module) + +for i, arg in enumerate(args): + # use zero filling if the arg is return value + set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) + set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) + +for arg in args: + arg.print_verbose(flags.verbose) + +if flags.verbose >= benchgc.util.MODULE_VERBOSE: + print(module) + +ref_args: List[torch.Tensor] = [] +gc_args: List[torch.Tensor | int] = [] +ref_tensors: Dict[str, torch.Tensor] = {} +gc_tensors: Dict[str, torch.Tensor] = {} + +for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + ref_tensors["%arg" + str(i)] = tensor.clone() + ref_args.append(ref_tensors["%arg" + str(i)]) + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + +# ref_out contains return value of the entry +ref_out = runner.ref_run(entry, ref_tensors) + +# we need to swap the result into the args if some arg is the return value +if ref_out is not None: + for i in range(len(ref_out)): + ref_args[0 - i - 1] = ref_out[0 - i - 1] + +mlir_args = get_mlir_args(gc_args) +passes = "any(gc-cpu-pipeline)" + +with module.context: + compiler = GraphCompiler(passes) + engine = compiler.compile_and_jit(module) + engine.invoke("entry", *mlir_args) + +fail, mistrust = False, False +for i in range(len(args)): + # gc_arg contains address for scalar value + # we need to find result by arg name + res = compare_tensor( + args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose + ) + fail = fail or (not res[0]) + if res[1] is not None: + mistrust = mistrust | res[1] +if fail: + print(f"FAIL: {flags.driver}.{flags.case}") + sys.exit(1) +elif mistrust: + print(f"MISTRUST: {flags.driver}.{flags.case}") +else: + print(f"PASSED: {flags.driver}.{flags.case}") diff --git a/test/benchgc/src/benchgc/arg/CMakeLists.txt b/test/benchgc/src/benchgc/arg/CMakeLists.txt new file mode 100644 index 000000000..614e306da --- /dev/null +++ b/test/benchgc/src/benchgc/arg/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/arg/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/arg/__init__.py b/test/benchgc/src/benchgc/arg/__init__.py new file mode 100644 index 000000000..a2134af2d --- /dev/null +++ b/test/benchgc/src/benchgc/arg/__init__.py @@ -0,0 +1,163 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Tuple + +import benchgc.arg.binary as binary +import benchgc.arg.compare +import benchgc.arg.conv as conv +import benchgc.arg.eltwise as eltwise +import benchgc.arg.matmul as matmul +import benchgc.arg.pool as pool +import benchgc.arg.softmax as softmax +import benchgc.util +import torch +from benchgc.arg.arg import Arg + +onednn_module = { + "binary": binary, + "eltwise": eltwise, + "matmul": matmul, + "softmax": softmax, + "conv": conv, + "pool": pool, +} + + +def set_default_fill( + flags: argparse.Namespace, arg: Arg, arglist: List[Arg], is_return: bool +): + if arg.fill_type != "-": + return + + if is_return: + arg.fill_type = "Z" + arg.fill_param = [] + return + + for _, module in onednn_module.items(): + if flags.driver + "." + flags.case in module.op: + module.default_fill(flags, arg, arglist) + return + # use N(0, 1) as default + arg.fill_type = "N" + arg.fill_param = ["0", "1"] + + +def set_default_compare( + flags: argparse.Namespace, arg: Arg, arglist: List[Arg], is_return: bool +): + if arg.cmp_type != "-": + return + + if is_return: + for _, module in onednn_module.items(): + if flags.driver + "." + flags.case in module.op: + module.default_compare(flags, arg, arglist) + return + + dtype: torch.dtype = benchgc.util.get_dtype(arg.dtype) + arg.cmp_type = "P" + if dtype.is_floating_point: + arg.cmp_param = [str(torch.finfo(dtype).eps)] + else: + arg.cmp_param = ["0"] + if is_return: + arg.cmp_param.append("70.0") + else: + arg.cmp_param.append("100.0") + + +def fill_tensor(flags: argparse.Namespace, arg: Arg, idx: int) -> torch.Tensor: + if arg.dtype == "" or arg.fill_type == "": + raise Exception("arg%d filling: dtype/fill_type is not set" % idx) + + # set the seed for the filling + benchgc.util.torch_seed(1, idx) + if arg.fill_type == "N" and len(arg.fill_param) == 2: + # Normal distribution + mean = float(arg.fill_param[0]) + std = float(arg.fill_param[1]) + tensor = torch.normal(mean=mean, std=std, size=arg.shape) + + elif arg.fill_type == "P" and len(arg.fill_param) == 1: + # Poisson distribution + _lambda = float(arg.fill_param[0]) + lambda_tensor = torch.full(arg.shape, _lambda) + tensor = torch.poisson(lambda_tensor) + elif arg.fill_type == "B" and len(arg.fill_param) == 2: + # Binomial distribution + n = int(arg.fill_param[0]) + p = float(arg.fill_param[1]) + bdist = torch.distributions.binomial.Binomial(total_count=n, probs=p) + tensor = bdist.sample(torch.Size(arg.shape)) + elif arg.fill_type == "U" and len(arg.fill_param) == 2: + # Uniform distribution + a = float(arg.fill_param[0]) + b = float(arg.fill_param[1]) + tensor = torch.distributions.uniform.Uniform(a, b).sample(torch.Size(arg.shape)) + elif arg.fill_type == "I" and len(arg.fill_param) == 2: + # integer range + a = int(arg.fill_param[0]) + b = int(arg.fill_param[1]) + tensor = torch.randint(a, b + 1, torch.Size(arg.shape)) + elif arg.fill_type == "F" and len(arg.fill_param) == 1: + # read from pytorch tensor dump file + filename = arg.fill_param[0] + tensor = torch.load(f=filename) + if not isinstance(tensor, torch.Tensor): + raise Exception(f"torch object from file {filename} is not a tensor object") + if tensor.shape != torch.Size(arg.shape): + raise Exception(f"tensor object from file {filename} does not match shape") + if tensor.dtype != benchgc.util.get_dtype(arg.dtype): + raise Exception(f"tensor object from file {filename} does not match dtype") + elif arg.fill_type == "D" and len(arg.fill_param) > 0: + # Driver fill + driver: str = arg.fill_param[0] + driver_module = onednn_module[driver] + tensor = driver_module.fill( + arg.shape, benchgc.util.get_dtype(arg.dtype), arg.fill_param[1:] + ) + elif arg.fill_type == "Z": + tensor = torch.zeros(size=arg.shape) + else: + raise Exception("invalid fill type or fill parameter") + + tensor = tensor.to(benchgc.util.get_dtype(arg.dtype)) + if flags.verbose >= benchgc.util.INPUT_VERBOSE: + print("fill arg%d: " % idx) + print(tensor) + return tensor + + +def compare_tensor( + arg: Arg, ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + + if arg.cmp_type == "P": # p2p check + threshold = float(arg.cmp_param[0]) + zero_percent = float(arg.cmp_param[1]) + return benchgc.arg.compare.p2p(threshold, zero_percent, ref, res, verbose) + if arg.cmp_type == "N": # norm check + threshold = float(arg.cmp_param[0]) + return benchgc.arg.compare.norm(threshold, ref, res, verbose) + elif arg.cmp_type == "D" and len(arg.cmp_param) > 0: # driver check + driver: str = arg.cmp_param[0] + driver_module = onednn_module[driver] + return driver_module.compare(arg.cmp_param[1:], ref, res, verbose) + else: + raise Exception("invalid compare type or compare parameter") diff --git a/test/benchgc/src/benchgc/arg/arg.py b/test/benchgc/src/benchgc/arg/arg.py new file mode 100644 index 000000000..3bf232a93 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/arg.py @@ -0,0 +1,54 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import List + +import benchgc.mlir.arg +import benchgc.util + + +class Arg(benchgc.mlir.arg.MLIRArg): + fill_type: str + fill_param: List[str] + + cmp_type: str + cmp_param: List[str] + + index: int + + def __init__(self, index: int): + self.dtype = "" + self.fill_type = "-" + self.fill_param = [] + self.cmp_type = "-" + self.cmp_param = [] + self.index = index + + def print_verbose(self, verbose: int): + if verbose >= benchgc.util.ARG_VERBOSE: + print( + f"arg{self.index} shape: {self.shape} dtype: {self.dtype} fill_type: {self.fill_type} fill_param: {self.fill_param} cmp_type: {self.cmp_type} cmp_param: {self.cmp_param}" + ) + + def set_fill(self, fill: str): + splited: List[str] = fill.split(":") + self.fill_type = splited[0] + self.fill_param = splited[1:] + + def set_cmp(self, cmp: str): + splited: List[str] = cmp.split(":") + self.cmp_type = splited[0] + self.cmp_param = splited[1:] diff --git a/test/benchgc/src/benchgc/arg/binary.py b/test/benchgc/src/benchgc/arg/binary.py new file mode 100644 index 000000000..6e2cb0c30 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/binary.py @@ -0,0 +1,98 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# op should use this filling + +op: Set[str] = set( + ["linalg.add", "linalg.div", "linalg.mul", "linalg.max", "linalg.min", "linalg.sub"] +) + +# params format: [src0 | src1, src0 dt, src1 dt, dst dt] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("binary fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "binary", + "src0" if arg.index == 0 else "src1", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, _, _, _ = params + + accept_name: Dict[str, int] = {"src0": 1, "src1": 2} + if name in accept_name: + arg: int = accept_name[name] + else: + raise Exception("unknown arg name %s", name) + + range_: int = 16 + f_min = 0 if dtype == torch.uint8 else -range_ // 2 + + idx: torch.Tensor = torch.arange( + benchgc.util.nelem(shape), dtype=torch.int + ).reshape(shape) + values: torch.Tensor = (f_min + (12 * idx + 5 * arg + 16) % (range_ + 1)) * 1.25 + if arg == 2: + values = torch.where(values == 0.0, 1, values) + return values.to(dtype=dtype) + + +# compare param: dtype, case + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["binary", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + + if param[1] in ["div", "div_unsigned"]: + abs_diff = (ref.to(torch.float) - res.to(torch.float)).abs() + init_check = abs_diff < benchgc.util.get_eps(dtype) + else: + init_check = None + + return p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose, init_check) diff --git a/test/benchgc/src/benchgc/arg/compare.py b/test/benchgc/src/benchgc/arg/compare.py new file mode 100644 index 000000000..2e4c31e85 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/compare.py @@ -0,0 +1,137 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Callable, List, Tuple + +import benchgc.util +import numpy +import torch + + +def iterate_tensor(tensor: torch.Tensor, fn: Callable[[Tuple[int, ...]], None]): + if tensor.ndim == 0: + fn(tuple()) + return + index: List[int] = [0] * tensor.ndim + + def dfs(depth: int): + if depth == tensor.ndim: + fn(tuple(index)) + else: + for i in range(tensor.shape[depth]): + index[depth] = i + dfs(depth + 1) + + dfs(0) + + +def norm( + threshold: float, ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + if f32_ref.nelement() == 0: + return (True, None) + + diff_square_sum = torch.square(torch.subtract(f32_ref, f32_res)).sum() + square_sum = torch.square(f32_ref).sum() + + l2_diff_norm = torch.sqrt(diff_square_sum / square_sum).item() + if verbose >= benchgc.util.COMPARE_VERBOSE: + print(f"norm check: {l2_diff_norm:.10f} / threshold: {threshold:.10f}") + + return (l2_diff_norm < threshold, None) + + +def p2p( + threshold: float, + zero_percent: float, + ref: torch.Tensor, + res: torch.Tensor, + verbose: int, + init_check: torch.Tensor | None = None, +) -> Tuple[bool, bool | None]: + + if verbose >= benchgc.util.COMPARE_VERBOSE: + print(f"p2p check: threshold: {threshold:.7f}") + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + + if init_check is None: + check = torch.tensor(False) + else: + check = init_check + + check = check.bitwise_or(torch.bitwise_and(f32_ref.isnan(), f32_res.isnan())) + check = check.bitwise_or(torch.bitwise_and(f32_ref.isneginf(), f32_res.isneginf())) + check = check.bitwise_or(torch.bitwise_and(f32_ref.isposinf(), f32_res.isposinf())) + + # choose diff/rel_diff based on value + abs_diff = (f32_ref - f32_res).abs() + rel_diff = abs_diff / torch.where( + f32_ref.abs() > numpy.finfo(numpy.float32).smallest_subnormal, + f32_ref.abs(), + 1, + ) + # pick a diff for comparison + diff = torch.where(f32_ref.abs() > 1e-5, rel_diff, abs_diff) + + check = check.bitwise_or(diff <= threshold) + + if verbose >= benchgc.util.OUTPUT_VERBOSE: + iterate_tensor( + check, + lambda idx: print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + idx, + f32_ref[idx].item(), + f32_res[idx].item(), + abs_diff[idx].item(), + rel_diff[idx].item(), + ) + ), + ) + if check.all(): + # check mistrusted + zero = res.nelement() - res.count_nonzero().item() + if res.nelement() < 10: + mistrust = False + else: + mistrust = zero * 100.0 / res.nelement() > zero_percent + return (True, mistrust) + else: + if ( + verbose < benchgc.util.OUTPUT_VERBOSE + ): # skip verbose print if full output tensor is alrady printed + fail = torch.argwhere(torch.where(check, 0, 1)) + if verbose < benchgc.util.ERROR_OUTPUT_VERBOSE: + # only print top 10 failed data points if verbose level does not satisfied + fail = fail[:10] + for idx in fail: + index: Tuple[int, ...] = tuple(idx.tolist()) + print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + index, + f32_ref[index].item(), + f32_res[index].item(), + abs_diff[index].item(), + rel_diff[index].item(), + ) + ) + return (False, None) diff --git a/test/benchgc/src/benchgc/arg/conv.py b/test/benchgc/src/benchgc/arg/conv.py new file mode 100644 index 000000000..ba46f201c --- /dev/null +++ b/test/benchgc/src/benchgc/arg/conv.py @@ -0,0 +1,189 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set( + [ + "linalg.conv_1d_ncw_fcw", + "linalg.conv_1d_nwc_wcf", + "linalg.conv_1d", + "linalg.conv_2d_nchw_fchw", + "linalg.conv_2d_ngchw_fgchw", + "linalg.conv_2d_ngchw_gfchw", + "linalg.conv_2d_nhwc_fhwc", + "linalg.conv_2d_nhwc_hwcf", + "linalg.conv_2d", + "linalg.conv_3d_ncdhw_fcdhw", + "linalg.conv_3d_ndhwc_dhwcf", + "linalg.conv_3d", + "linalg.depthwise_conv_1d_ncw_cw", + "linalg.depthwise_conv_1d_nwc_wc", + "linalg.depthwise_conv_1d_nwc_wcm", + "linalg.depthwise_conv_2d_nchw_chw", + "linalg.depthwise_conv_2d_nhwc_hwc", + "linalg.depthwise_conv_2d_nhwc_hwcm", + "linalg.depthwise_conv_3d_ncdhw_cdhw", + "linalg.depthwise_conv_3d_ndhwc_dhwc", + "linalg.depthwise_conv_3d_ndhwc_dhwcm", + ] +) + +# params format: [src | wei, src dt, wei dt, dst dt, amp] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 2: + raise Exception("conv fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "conv", + "src" if arg.index == 0 else "wei", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + # find the amplifier of the conv + wei = arglist[1] + nelem = wei.nelem() + if flags.driver == "linalg": + if flags.case in [ + "conv_1d_ncw_fcw", + "conv_2d_nchw_fchw", + "conv_2d_ngchw_fgchw", + "conv_2d_nhwc_fhwc", + "conv_3d_ncdhw_fcdhw", + ]: + arg.fill_param.append(str(nelem // wei.shape[0])) + elif flags.case in ["conv_2d_ngchw_gfchw"]: + arg.fill_param.append(str(nelem // wei.shape[1])) + elif flags.case in [ + "conv_1d_nwc_wcf", + "conv_2d_nhwc_hwcf", + "conv_3d_ndhwc_dhwcf", + "depthwise_conv_1d_nwc_wcm", + "depthwise_conv_2d_nhwc_hwcm", + "depthwise_conv_3d_ndhwc_dhwcm", + ]: + arg.fill_param.append(str(nelem // wei.shape[-1])) + elif flags.case in [ + "conv_1d", + "conv_2d", + "conv_3d", + "depthwise_conv_1d_ncw_cw", + "depthwise_conv_1d_nwc_wc", + "depthwise_conv_2d_nchw_chw", + "depthwise_conv_2d_nhwc_hwc", + "depthwise_conv_3d_ncdhw_cdhw", + "depthwise_conv_3d_ndhwc_dhwc", + ]: + arg.fill_param.append(str(nelem)) + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, src_dt, wei_dt, dst_dt, amp = params + + arg_rng: List[Dict[torch.dtype, Tuple[int, int]]] = [ + { + torch.float32: (-32, 32), + torch.bfloat16: (-4, 4), + torch.float16: (-4, 4), + }, # src + { + torch.float32: (-32, 32), + torch.bfloat16: (-8, 8), + torch.float16: (-2, 2), + }, # wei + ] + + target = torch.empty(size=shape, dtype=torch.float32) + target = target.view(-1) + + src_dt = benchgc.util.get_dtype(src_dt) + wei_dt = benchgc.util.get_dtype(wei_dt) + + src_min, src_max = arg_rng[0][src_dt] + wei_min, wei_max = arg_rng[1][wei_dt] + max_value = max(abs(src_min), abs(src_max)) * max(abs(wei_min), abs(wei_max)) + safe_digits: int = min( + benchgc.util.get_digits("f32"), benchgc.util.get_digits(dst_dt) + ) + safe_n_acc = (1 << safe_digits) // max_value + + if name == "src": + arg_min, arg_max = arg_rng[0][src_dt] + density = 1.0 + elif name == "wei": + arg_min, arg_max = arg_rng[1][wei_dt] + density = min(safe_n_acc / int(amp), 1.0) + else: + raise Exception("unknown arg name %s", name) + + benchgc.util.torch_seed() + + density_t = torch.full(shape, density, dtype=torch.float32) + bernoulli_t = torch.bernoulli(density_t) + condi = density_t == 1 + is_one_t = torch.where(condi, True, bernoulli_t) + gen_value = torch.randint(arg_min, arg_max + 1, size=shape) + target = is_one_t * gen_value + + # make sure the first element is positive + first_val = target.flatten()[0] + if first_val <= 0.0: + while first_val <= 0.0: + first_val = torch.randint(arg_min, arg_max + 1, size=()).item() + target_f = target.view(-1) + target_f[0] = first_val + target = target_f.view(shape) + + return target.to(dtype=dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["conv", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + 0.0, # use a relax threshold if using wino + 70.0 if dtype == torch.uint8 else 85.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/eltwise.py b/test/benchgc/src/benchgc/arg/eltwise.py new file mode 100644 index 000000000..bf7dbff54 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/eltwise.py @@ -0,0 +1,177 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# params format: [alg, alpha, beta] + +# op should use this filling + +op: Set[str] = set( + [ + "linalg.abs", + "linalg.negf", + "linalg.exp", + "linalg.ceil", + "linalg.erf", + "linalg.floor", + "linalg.log", + "linalg.round", + "linalg.rsqrt", + "linalg.sqrt", + "linalg.square", + "linalg.tanh", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 0: + raise Exception("eltwise fill: dst filling is not allowed") + arg.fill_param = ["eltwise", flags.case] + if flags.driver == "linalg" and flags.case in [ + "abs", + "exp", + "ceil", + "erf", + "floor", + "log", + "round", + "sqrt", + "square", + "tanh", + ]: + arg.fill_param.extend(["", ""]) + elif flags.driver == "linalg" and flags.case == "negf": + arg.fill_param.extend(["-1", "0"]) + elif flags.driver == "linalg" and flags.case == "rsqrt": + arg.fill_param.extend(["1", "-0.5"]) + arg.fill_type = "D" + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + alg, alpha, beta = params + nelems = benchgc.util.nelem(shape) + + float_limit: torch.finfo = torch.finfo(torch.float32) + + alpha = 0.0 if alpha == "" else float(alpha) + beta = 0.0 if beta == "" else float(beta) + + coeff = torch.tensor( + [1, -1, 1, -1, 10.0, -10.0, 10.0, -10.0, 10.0, 10.0, 10.0, 1, 1] + ) + bias = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 88.0, 22.0, 44.0, alpha, beta]) + rand_int_mask = torch.tensor( + [ + True, + True, + False, + False, + True, + True, + False, + False, + False, + False, + False, + False, + False, + ] + ) + rand_uni_mask = torch.tensor( + [ + False, + False, + True, + True, + False, + False, + True, + True, + True, + True, + True, + False, + False, + ] + ) + + if alg == "log": + # append more value for Log validation + coeff = torch.cat((coeff, torch.tensor([1, 1])), dim=0) + bias = torch.cat( + (bias, torch.tensor([float_limit.max, float_limit.min])), dim=0 + ) + rand_int_mask = torch.cat((rand_int_mask, torch.tensor([False, False])), dim=0) + rand_uni_mask = torch.cat((rand_uni_mask, torch.tensor([False, False])), dim=0) + + repeats: int = (nelems + coeff.nelement() - 1) // coeff.nelement() + + coeff = coeff.repeat(repeats)[:nelems] + bias = bias.repeat(repeats)[:nelems] + + rand_int_mask = rand_int_mask.repeat(repeats)[:nelems] + benchgc.util.torch_seed() + rand_int = torch.where(rand_int_mask, torch.randint(0, 10, [nelems]), 0) + + rand_uni_mask = rand_uni_mask.repeat(repeats)[:nelems] + rand_uni = torch.where(rand_uni_mask, torch.rand(nelems) * 0.09, 0) + + value = ((rand_int + rand_uni) * coeff + bias).to(dtype=dtype) + return value.reshape(shape) + + +# param: dtype, case + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["eltwise", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + ref = ref.to(torch.float) + res = res.to(torch.float) + + threshold = 4e-6 if dtype == torch.float else benchgc.util.get_eps(dtype) + if dtype == torch.float and param[1] in ["tanh", "log"]: + threshold = 4e-5 + + return p2p( + threshold, + 65.0 if dtype.is_floating_point else 100.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/matmul.py b/test/benchgc/src/benchgc/arg/matmul.py new file mode 100644 index 000000000..76d0ce81e --- /dev/null +++ b/test/benchgc/src/benchgc/arg/matmul.py @@ -0,0 +1,173 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +# params format: [src | wei, src dt, wei dt, dst dt, amp] +# use other filling type for bias + +op: Set[str] = set( + [ + "linalg.batch_matmul", + "linalg.batch_matmul_transpose_a", + "linalg.batch_matmul_transpose_b", + "linalg.batch_matvec", + "linalg.batch_mmt4d", + "linalg.batch_vecmat", + "linalg.batch_reduce_matmul", + "linalg.dot", + "linalg.matmul", + "linalg.matmul_transpose_a", + "linalg.matmul_transpose_b", + "linalg.matvec", + "linalg.mmt4d", + "linalg.vecmat", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("matmul fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = [ + "matmul", + "src" if arg.index == 0 else "wei", + arglist[0].dtype, + arglist[1].dtype, + arglist[2].dtype, + ] + + # find the amplifier K of the matmul + if flags.driver == "linalg": + if ( + flags.case == "matmul_transpose_b" + or flags.case == "batch_matmul" + and arg.index == 0 + or flags.case == "batch_matmul_transpose_b" + or flags.case == "batch_matvec" + or flags.case == "batch_vecmat" + and arg.index == 0 + or flags.case == "matmul" + and arg.index == 0 + or flags.case == "matvec" + or flags.case == "vecmat" + and arg.index == 0 + or flags.case == "dot" + ): + arg.fill_param.append(str(arg.shape[-1])) + elif ( + flags.case == "batch_matmul" + and arg.index == 1 + or flags.case == "batch_matmul_transpose_a" + or flags.case == "batch_vecmat" + and arg.index == 1 + or flags.case == "matmul" + and arg.index == 1 + or flags.case == "matmul_transpose_a" + or flags.case == "vecmat" + and arg.index == 1 + ): + arg.fill_param.append(str(arg.shape[-2])) + elif flags.case == "batch_mmt4d" or flags.case == "mmt4d": + arg.fill_param.append(str(arg.shape[-1] * arg.shape[-3])) + # reduce the matmul will amplified by B * K + elif flags.case == "batch_reduce_matmul" and arg.index == 0: + arg.fill_param.append(str(arg.shape[-1] * arg.shape[0])) + elif flags.case == "batch_reduce_matmul" and arg.index == 1: + arg.fill_param.append(str(arg.shape[-2] * arg.shape[0])) + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + name, src_dt, wei_dt, dst_dt, amp = params + + arg_rng: List[Dict[torch.dtype, Tuple[int, int]]] = [ + { + torch.float32: (-64, 64), + torch.bfloat16: (-4, 4), + torch.float16: (-4, 4), + }, # src + { + torch.float32: (-128, 128), + torch.bfloat16: (-8, 8), + torch.float16: (-2, 2), + }, # wei + ] + + src_dt = benchgc.util.get_dtype(src_dt) + wei_dt = benchgc.util.get_dtype(wei_dt) + + src_min, src_max = arg_rng[0][src_dt] + wei_min, wei_max = arg_rng[1][wei_dt] + max_value = max(abs(src_min), abs(src_max)) * max(abs(wei_min), abs(wei_max)) + safe_digits: int = min( + benchgc.util.get_digits("f32"), benchgc.util.get_digits(dst_dt) + ) + + safe_n_acc = (1 << safe_digits) // max_value + + if name == "src": + arg_min, arg_max = arg_rng[0][src_dt] + density = 1.0 + elif name == "wei": + arg_min, arg_max = arg_rng[1][wei_dt] + density = min(safe_n_acc / int(amp), 1.0) + else: + raise Exception("unknown arg name %s", name) + + benchgc.util.torch_seed(1, 0 if name == "src" else 1) + value = torch.bernoulli(torch.full(shape, density)) * torch.randint( + arg_min, arg_max, shape + ) + while value.flatten()[0] <= 0: + value.flatten()[0] = torch.randint(arg_min, arg_max + 1, size=[1])[0].item() + + return value.to(dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["matmul", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + 1e-6 if dtype == torch.float else benchgc.util.get_eps(dtype), + 90.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/pool.py b/test/benchgc/src/benchgc/arg/pool.py new file mode 100644 index 000000000..7179b7f91 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/pool.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set( + [ + "linalg.pooling_nchw_max", + "linalg.pooling_nchw_sum", + "linalg.pooling_ncw_max", + "linalg.pooling_ncw_sum", + "linalg.pooling_ndhwc_max", + "linalg.pooling_ndhwc_sum", + "linalg.pooling_nhwc_max", + "linalg.pooling_nhwc_sum", + "linalg.pooling_nhwc_min", + "linalg.pooling_nwc_max", + "linalg.pooling_nwc_min", + "linalg.pooling_nwc_sum", + ] +) + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 1: + raise Exception("pool fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = ["pool"] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + arg_rng: Tuple[int, int] = { + torch.float64: (-2048, 2048), + torch.float32: (-2048, 2048), + torch.int32: (-2048, 2048), + torch.bfloat16: (-32, 32), + torch.float16: (-32, 32), + torch.int8: (-128, 127), + torch.uint8: (0, 255), + }[dtype] + + benchgc.util.torch_seed() + target = torch.randint(arg_rng[0], arg_rng[1] + 1, size=[benchgc.util.nelem(shape)]) + # make sure the first element is not negative + if target[0] <= 0.0: + while target[0] <= 0.0: + target[0] = torch.randint(arg_rng[0], arg_rng[1], size=(1,))[0].item() + + return target.reshape(shape).to(dtype=dtype) + + +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = ["pool", arg.dtype, flags.case] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + + ref = ref.to(torch.float) + res = res.to(torch.float) + return p2p( + benchgc.util.get_eps(dtype) * 10.0, + 99.0, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arg/reduce.py b/test/benchgc/src/benchgc/arg/reduce.py new file mode 100644 index 000000000..bc75e5d84 --- /dev/null +++ b/test/benchgc/src/benchgc/arg/reduce.py @@ -0,0 +1,78 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import List, Tuple + +import benchgc.arg +import benchgc.util +import torch + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + + op, sdtype, ddtype, amp = params + + sdtype = benchgc.util.get_dtype(sdtype) + ddtype = benchgc.util.get_dtype(ddtype) + + safe_to_reduce_elems: int = benchgc.util.get_problem_bounds(op, sdtype)[0] + + neutral_value: float = 1.0 if op == "mul" else 0.0 + + shift: float = ( + 1.0 + if ( + op == "mean" + or op == "min" + and not sdtype.is_signed + and not ddtype.is_signed + ) + else 0.0 + ) + + value_range: int = benchgc.util.get_problem_bounds(op, sdtype)[1] + + is_mul_fp: bool = op == "mul" and sdtype.is_floating_point + min_range: int = -value_range if is_mul_fp else 1 + + index = torch.arange(benchgc.util.nelem(shape)).reshape(shape) + + benchgc.util.torch_seed() + value = torch.randint(min_range, value_range + 1, size=shape) + if is_mul_fp: + value = torch.pow(2, value) + if sdtype.is_signed: # random choose positive or negative + value = torch.where(torch.BoolTensor(size=shape), value, -value) + + non_neutral_mask = benchgc.util.flip_coin( + index, + torch.full(shape, safe_to_reduce_elems / int(amp), dtype=torch.float32), + ) + if isinstance(non_neutral_mask, torch.Tensor): + value = torch.where(non_neutral_mask, value, neutral_value) + else: + raise Exception("Flip coin failed when generate the reduce data filling") + value = value + shift + return value.to(dtype) + + +def compare( + ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = ref.dtype + ref = ref.to(torch.float) + res = res.to(torch.float) + return benchgc.arg.p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose) diff --git a/test/benchgc/src/benchgc/arg/softmax.py b/test/benchgc/src/benchgc/arg/softmax.py new file mode 100644 index 000000000..a9731ec0a --- /dev/null +++ b/test/benchgc/src/benchgc/arg/softmax.py @@ -0,0 +1,93 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import operator +from functools import reduce +from typing import List, Set, Tuple + +import benchgc.util +import torch +from benchgc.arg.arg import Arg +from benchgc.arg.compare import p2p + +op: Set[str] = set(["linalg.softmax"]) + + +# params format: [reduce dimension] + + +def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + if arg.index > 0: + raise Exception("softmax fill: dst filling is not allowed") + arg.fill_type = "D" + arg.fill_param = ["softmax", str(flags.dimension)] + + +def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor: + dimension: int = int(params[0]) + + outer: int = reduce(operator.mul, shape[:dimension], 1) + inner: int = reduce(operator.mul, shape[dimension + 1 :], 1) + benchgc.util.torch_seed() + sign = torch.randint(0, 1, size=[1, shape[dimension], 1]) * 2 - 1 + value = torch.randint(87, 90, size=[outer, shape[dimension], inner]) + value = torch.where(value == 87, 0, value) + value = value * sign + value = torch.where(value == 0, torch.finfo(dtype).min, value) + return value.reshape(shape).to(dtype) + + +# param: dtype, case, reduce size +def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], +): + arg.cmp_type = "D" + arg.cmp_param = [ + "softmax", + arg.dtype, + flags.case, + str(arg.shape[int(flags.dimension)]), + ] + + +def compare( + param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int +) -> Tuple[bool, bool | None]: + dtype = benchgc.util.get_dtype(param[0]) + ref = ref.to(torch.float) + res = res.to(torch.float) + + reduce_size = int(param[2]) + nzeros = ( + reduce_size - 1 + if dtype == torch.int8 or dtype == torch.uint8 + else max(0, reduce_size - 8) + ) + + return p2p( + benchgc.util.get_eps(dtype) * (5.0 if dtype == torch.float else 1.0), + 100.0 * nzeros / reduce_size, + ref, + res, + verbose, + ) diff --git a/test/benchgc/src/benchgc/arith/CMakeLists.txt b/test/benchgc/src/benchgc/arith/CMakeLists.txt new file mode 100644 index 000000000..63d2bfa79 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/arith COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/arith/__init__.py b/test/benchgc/src/benchgc/arith/__init__.py new file mode 100644 index 000000000..a5f942a72 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/__init__.py @@ -0,0 +1,45 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[ + str, Callable[[argparse.Namespace, List[Arg], List[Arg]], gc_mlir.ir.Module] +] = {} + +for dri in ["basic"]: + mod = importlib.import_module(f"benchgc.arith.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/arith/basic.py b/test/benchgc/src/benchgc/arith/basic.py new file mode 100644 index 000000000..7e4b17467 --- /dev/null +++ b/test/benchgc/src/benchgc/arith/basic.py @@ -0,0 +1,58 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import benchgc.util +import gc_mlir._mlir_libs._mlir.ir +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_constant( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + value = op.attributes["value"] + if isinstance(value, gc_mlir._mlir_libs._mlir.ir.FloatAttr): + return ( + torch.full(size=tuple(), fill_value=value.__float__(), dtype=torch.float), + ) + elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.DenseFPElementsAttr): + if value.is_splat: + return ( + torch.full( + size=tuple(value.type.shape), + fill_value=value.get_splat_value().value, + dtype=benchgc.util.get_dtype(str(value.get_splat_value().type)), + ), + ) + else: + raise Exception("only support splat value now") + else: + raise Exception("Not support constant type %s", type(value)) + + +def ref_mulf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (var[cache.opr[0]] * var[cache.opr[1]],) + + +def ref_addf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (var[cache.opr[0]] + var[cache.opr[1]],) diff --git a/test/benchgc/src/benchgc/linalg/CMakeLists.txt b/test/benchgc/src/benchgc/linalg/CMakeLists.txt new file mode 100644 index 000000000..8daf7848a --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/linalg/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/linalg/__init__.py b/test/benchgc/src/benchgc/linalg/__init__.py new file mode 100644 index 000000000..331bd75dd --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/__init__.py @@ -0,0 +1,52 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], gc_mlir.ir.Module]] = {} + +for dri in [ + "binary", + "matmul", + "eltwise", + "misc", + "generic", + "softmax", + "conv", + "pool", +]: + mod = importlib.import_module(f"benchgc.linalg.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/linalg/binary.py b/test/benchgc/src/benchgc/linalg/binary.py new file mode 100644 index 000000000..ed5d280a3 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/binary.py @@ -0,0 +1,137 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_add( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.add(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.add(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_powf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.pow(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.powf(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_div( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.div(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.div(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.max(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.min(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_mul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mul(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.mul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_sub( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.sub(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.sub(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/conv.py b/test/benchgc/src/benchgc/linalg/conv.py new file mode 100644 index 000000000..c8fc38efb --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/conv.py @@ -0,0 +1,834 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_conv_1d_ncw_fcw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_1d_ncw_fcw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_ncw_fcw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d_nwc_wcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # src: nwc -> ncw + # wei: wcf -> fcw + # dst: nwf -> nfw + + return ( + torch.conv1d( + var[cache.opr[0]].permute([0, 2, 1]), + var[cache.opr[1]].permute([2, 1, 0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_conv_1d_nwc_wcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_nwc_wcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d_ncw_fcw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_1d_ncw_fcw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d_ncw_fcw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_1d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv1d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_1d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_conv_2d_nchw_fchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_2d_nchw_fchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nchw_fchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_ngchw_fgchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + src = var[cache.opr[0]] + wei = var[cache.opr[1]] + groups: int = src.shape[1] + + dst = torch.conv2d( + src.reshape( + [src.shape[0], src.shape[1] * src.shape[2], src.shape[3], src.shape[4]] + ), # merge group axis with channel + wei.transpose(0, 1) + .contiguous() + .reshape( + [wei.shape[0] * wei.shape[1], wei.shape[2], wei.shape[3], wei.shape[4]] + ), # merge group axis with output channel + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + return ( + dst.reshape( + [dst.shape[0], groups, dst.shape[1] // groups, dst.shape[2], dst.shape[3]] + ), + ) # split group axis from output channel + + +def mlir_conv_2d_ngchw_fgchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_ngchw_fgchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_ngchw_gfchw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + src = var[cache.opr[0]] + wei = var[cache.opr[1]] + groups: int = src.shape[1] + + dst = torch.conv2d( + src.reshape( + [src.shape[0], src.shape[1] * src.shape[2], src.shape[3], src.shape[4]] + ), # merge group axis with channel + wei.reshape( + [wei.shape[0] * wei.shape[1], wei.shape[2], wei.shape[3], wei.shape[4]] + ), # merge group axis with output channel + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + return ( + dst.reshape( + [dst.shape[0], groups, dst.shape[1] // groups, dst.shape[2], dst.shape[3]] + ), + ) # split group axis from output channel + + +def mlir_conv_2d_ngchw_gfchw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_ngchw_gfchw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_nhwc_fhwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([0, 3, 1, 2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_conv_2d_nhwc_fhwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nhwc_fhwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d_nhwc_hwcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([3, 2, 0, 1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_conv_2d_nhwc_hwcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d_nhwc_hwcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_2d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv2d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_2d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_conv_3d_ncdhw_fcdhw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv3d( + var[cache.opr[0]], + var[cache.opr[1]], + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_conv_3d_ncdhw_fcdhw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d_ncdhw_fcdhw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_3d_ndhwc_dhwcf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + var[cache.opr[1]].permute([4, 3, 0, 1, 2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_conv_3d_ndhwc_dhwcf( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d_ndhwc_dhwcf( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_conv_3d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.conv3d( + var[cache.opr[0]].unsqueeze(0).unsqueeze(0), + var[cache.opr[1]].unsqueeze(0).unsqueeze(0), + ) + .squeeze(0) + .squeeze(0), + ) + + +def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.conv_3d( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + ) + ], + ) + + +def ref_depthwise_conv_1d_ncw_cw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv1d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_1d_ncw_cw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_ncw_cw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_1d_nwc_wc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv1d( + var[cache.opr[0]].transpose(-1, -2), + var[cache.opr[1]].transpose(-1, -2).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .transpose(-1, -2) + .contiguous(), + ) + + +def mlir_depthwise_conv_1d_nwc_wc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_nwc_wc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_1d_nwc_wcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + src = var[cache.opr[0]] + groups: int = src.shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv1d( + src.transpose(-1, -2), + wei.reshape([wei.shape[0], wei.shape[1] * wei.shape[2]]) + .transpose(-1, -2) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .transpose(-1, -2) + .contiguous() + ) + return (dst.reshape([dst.shape[0], dst.shape[1], wei.shape[1], wei.shape[2]]),) + + +def mlir_depthwise_conv_1d_nwc_wcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_1d_nwc_wcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nchw_chw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv2d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_2d_nchw_chw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nchw_chw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nhwc_hwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + var[cache.opr[1]].permute([2, 0, 1]).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_depthwise_conv_2d_nhwc_hwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nhwc_hwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_2d_nhwc_hwcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv2d( + var[cache.opr[0]].permute([0, 3, 1, 2]), + wei.reshape([wei.shape[0], wei.shape[1], wei.shape[2] * wei.shape[3]]) + .permute([2, 0, 1]) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 1]) + .contiguous() + ) + return ( + dst.reshape( + [dst.shape[0], dst.shape[1], dst.shape[2], wei.shape[-2], wei.shape[-1]] + ), + ) + + +def mlir_depthwise_conv_2d_nhwc_hwcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_2d_nhwc_hwcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ncdhw_cdhw( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[1] + return ( + torch.conv3d( + var[cache.opr[0]], + var[cache.opr[1]].unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ), + ) + + +def mlir_depthwise_conv_3d_ncdhw_cdhw( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ncdhw_cdhw( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ndhwc_dhwc( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + var[cache.opr[1]].permute([3, 0, 1, 2]).unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_depthwise_conv_3d_ndhwc_dhwc( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ndhwc_dhwc( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_depthwise_conv_3d_ndhwc_dhwcm( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + groups: int = var[cache.opr[0]].shape[-1] + wei = var[cache.opr[1]] + dst = ( + torch.conv3d( + var[cache.opr[0]].permute([0, 4, 1, 2, 3]), + wei.reshape( + [wei.shape[0], wei.shape[1], wei.shape[2], wei.shape[3] * wei.shape[4]] + ) + .permute([3, 0, 1, 2]) + .unsqueeze(1), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=groups, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous() + ) + return ( + dst.reshape( + [ + dst.shape[0], + dst.shape[1], + dst.shape[2], + dst.shape[3], + wei.shape[-2], + wei.shape[-1], + ] + ), + ) + + +def mlir_depthwise_conv_3d_ndhwc_dhwcm( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.depthwise_conv_3d_ndhwc_dhwcm( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/eltwise.py b/test/benchgc/src/benchgc/linalg/eltwise.py new file mode 100644 index 000000000..7ae9b31b7 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/eltwise.py @@ -0,0 +1,197 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_abs( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.abs(var[cache.opr[0]]),) + + +def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.abs(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_ceil( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.ceil(var[cache.opr[0]]),) + + +def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.ceil(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_floor( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.floor(var[cache.opr[0]]),) + + +def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.floor(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_erf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.erf(var[cache.opr[0]]),) + + +def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.erf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.log(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_log( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.log(var[cache.opr[0]]),) + + +def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_negf( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.neg(var[cache.opr[0]]),) + + +def ref_exp( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.exp(var[cache.opr[0]]),) + + +def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_round( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # torch.round is following the priciple "round half to even" + # we need another implementation + + v = torch.floor(var[cache.opr[0]]) + return (v + torch.where(var[cache.opr[0]] - v >= 0.5, 1, 0),) + + +def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.round(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_rsqrt( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.rsqrt(var[cache.opr[0]]),) + + +def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.rsqrt(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_sqrt( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.sqrt(var[cache.opr[0]]),) + + +def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.sqrt(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_square( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.square(var[cache.opr[0]]),) + + +def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.square(arg0, outs=[args[1].get_zero_op(ctx)])], + ) + + +def ref_tanh( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.tanh(var[cache.opr[0]]),) + + +def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [linalg.tanh(arg0, outs=[args[1].get_zero_op(ctx)])], + ) diff --git a/test/benchgc/src/benchgc/linalg/generic.py b/test/benchgc/src/benchgc/linalg/generic.py new file mode 100644 index 000000000..67228ab47 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/generic.py @@ -0,0 +1,236 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Any, Dict, List, Tuple + +import benchgc.runner +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def generic_loop( + cache: MLIRCache, + op: gc_mlir.ir.OpView, + depth: int, + iterspace: Dict[str, Tuple[int, int, int]], + affine_from: List[str], + affine_to: List[List[str]], + var: Dict[str, torch.Tensor], + loop_var: Dict[str, torch.Tensor], + result_tensors: Tuple[torch.Tensor, ...], +): + if depth == len(affine_from): + # we need to execute the block here + # we will need to read the block argument name and save it into the cache + + if len(cache.next) == 0: + # region cache + cache.next.append(MLIRCache()) + + block: gc_mlir.ir.Block = op.regions[0].blocks[0] + if len(cache.next[0].next) == 0: + # region->block cache + cache.next[0].next.append(MLIRCache()) + for arg in block.arguments: + cache.next[0].next[0].arg.append(arg.get_name()) + + block_cache = cache.next[0].next[0] + block_arg: Dict[str, torch.Tensor] = {} + for i in range(len(block.arguments)): + index: Tuple[int, ...] = tuple() + aff: List[str] = affine_to[i] + for d in aff: + index = index + (int(loop_var[d].item()),) + + if i + len(op.results) < len(op.regions[0].blocks[0].arguments): + # input argument + block_arg[block_cache.arg[i]] = var[cache.opr[i]][index] + else: + # output argument + block_arg[block_cache.arg[i]] = result_tensors[ + i + len(op.results) - len(block.arguments) + ][index] + + res: Tuple[Any, ...] = benchgc.runner.dfs_block( + cache.next[0].next[0], block, var | loop_var | block_arg + ) + + # perform the yield operation + for i in range(len(op.results)): + idx = -1 - i + aff: List[str] = affine_to[idx] + index: Tuple[int, ...] = tuple() + for d in aff: + index = index + (int(loop_var[d].item()),) + result_tensors[idx][index] = res[idx] + else: + it = iterspace[affine_from[depth]] + for i in range(it[0], it[1], it[2]): + loop_var[affine_from[depth]][0] = i + generic_loop( + cache, + op, + depth + 1, + iterspace, + affine_from, + affine_to, + var, + loop_var, + result_tensors, + ) + + +def ref_generic( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + affine_from: List[str] = [] + affine_to: List[List[str]] = [] + + for affine in op.attributes["indexing_maps"]: + aff = str(affine) + affine_from = aff[aff.find("<(") + 2 : aff.find(") ->")].split(", ") + affine_to.append(aff[aff.find("-> (") + 4 : aff.find(")>")].split(", ")) + + # try to find the iteration space + # TODO: support affine expression + + iterspace: Dict[str, Tuple[int, int, int]] = {} + operands: List[gc_mlir.ir.OpOperand] = list(op.operands) + + loop_var: Dict[str, torch.Tensor] = {} + for d in affine_from: + iterspace[d] = (0, 0, 1) + loop_var[d] = torch.zeros(size=[1], dtype=torch.int) + + for i in range(len(operands)): + for j in range(len(operands[i].type.shape)): + iterspace[affine_to[i][j]] = (0, operands[i].type.shape[j], 1) + + result_tensors: Tuple[torch.Tensor, ...] = tuple() + # create the buffer for result tensors + for i in range(len(op.results)): + result_tensors = result_tensors + (tensors[cache.opr[-1 - i]].clone(),) + + generic_loop( + cache, + op, + 0, + iterspace, + affine_from, + affine_to, + tensors, + loop_var, + result_tensors, + ) + return result_tensors + + +def reduce_loop( + cache: MLIRCache, + op: gc_mlir.ir.OpView, + depth: int, + in_shape: List[int], + var: Dict[str, torch.Tensor], + in_idx: List[int], + out_idx: List[int], + reduced_axis: int, + result_tensor: torch.Tensor, +): + if depth == len(in_shape): + # we need to execute the block here + # we will need to read the block argument name and save it into the cache + + block: gc_mlir.ir.Block = op.regions[0].blocks[0] + + if len(cache.next) == 0: + # region cache + cache.next.append(MLIRCache()) + if len(cache.next[0].next) == 0: + # region->block cache + cache.next[0].next.append(MLIRCache()) + for arg in block.arguments: + cache.next[0].next[0].arg.append(arg.get_name()) + + block_arg: Dict[str, torch.Tensor] = { + # set input + cache.next[0].next[0].arg[0]: var[cache.opr[0]][tuple(in_idx)], + # set output + cache.next[0].next[0].arg[1]: result_tensor[tuple(out_idx)], + } + + res: Tuple[torch.Tensor, ...] = benchgc.runner.dfs_block( + cache.next[0].next[0], op.regions[0].blocks[0], var | block_arg + ) + + # perform the yield operation + result_tensor[tuple(out_idx)] = res[0] + else: + dimensions: gc_mlir.ir.DenseI64ArrayAttr = op.attributes["dimensions"] + reduce_axis: bool = depth in list(dimensions) + + for i in range(in_shape[depth]): + if reduce_axis: + in_idx[depth] = i + reduce_loop( + cache, + op, + depth + 1, + in_shape, + var, + in_idx, + out_idx, + reduced_axis + 1, + result_tensor, + ) + else: + in_idx[depth] = i + out_idx[depth - reduced_axis] = i + reduce_loop( + cache, + op, + depth + 1, + in_shape, + var, + in_idx, + out_idx, + reduced_axis, + result_tensor, + ) + + +def ref_reduce( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # create the buffer for result tensors + tensors[cache.res[0]] = tensors[cache.opr[-1]].clone() + in_shape: List[int] = list(op.operands[0].type.shape) + out_shape: List[int] = list(op.result.type.shape) + + result_tensor: torch.Tensor = tensors[cache.opr[-1]].clone() + reduce_loop( + cache, + op, + 0, + in_shape, + tensors, + [0] * len(in_shape), + [0] * len(out_shape), + 0, + result_tensor, + ) + return (result_tensor,) diff --git a/test/benchgc/src/benchgc/linalg/matmul.py b/test/benchgc/src/benchgc/linalg/matmul.py new file mode 100644 index 000000000..9efde9612 --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/matmul.py @@ -0,0 +1,317 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg +from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType + + +def ref_batch_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.matmul(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matmul_transpose_a( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.bmm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) + + +def mlir_batch_matmul_transpose_a( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul_transpose_a(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matmul_transpose_b( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.bmm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) + + +def mlir_batch_matmul_transpose_b( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matmul_transpose_b(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_matvec( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # pytorch does not support bmv + return ( + torch.matmul(var[cache.opr[0]], var[cache.opr[1]].unsqueeze(-1)).squeeze(-1), + ) + + +def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_matvec(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_mmt4d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # [B, m, k, m0, k0] -> [B, m, m0, k, k0] + _src = var[cache.opr[0]].permute([0, 1, 3, 2, 4]).contiguous() + # [B, n, k, n0, k0] -> [B, k, k0, n, n0] + _wei = var[cache.opr[1]].permute([0, 2, 4, 1, 3]).contiguous() + + # [B, m, m0, k, k0] -> [B, M, K] + src = _src.reshape( + [_src.shape[0], _src.shape[1] * _src.shape[2], _src.shape[3] * _src.shape[4]] + ) + # [B, k, k0, n, n0] -> [B, K, N] + wei = _wei.reshape( + [_wei.shape[0], _wei.shape[1] * _wei.shape[2], _wei.shape[3] * _wei.shape[4]] + ) + + dst = torch.bmm(src, wei) + # [B, M, N] -> [B, m, m0, n, n0] + dst = dst.reshape( + [dst.shape[0], _src.shape[1], _src.shape[2], _wei.shape[-2], _wei.shape[-1]] + ) + + # [B, m, m0, n, n0] -> [B, m, n, m0, n0] + return (dst.transpose(2, 3).contiguous(),) + + +def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_mmt4d(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_reduce_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.addbmm( + input=torch.zeros(tuple()), + batch1=var[cache.opr[0]], + batch2=var[cache.opr[1]], + beta=0, + alpha=1, + ), + ) + + +def mlir_batch_reduce_matmul( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_reduce_matmul(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_batch_vecmat( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), + ) + + +def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.batch_vecmat(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_dot( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.dot(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.dot(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_matmul( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matmul_transpose_a( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) + + +def mlir_matmul_transpose_a( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul_transpose_a( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matmul_transpose_b( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) + + +def mlir_matmul_transpose_b( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matmul_transpose_b( + arg0, arg1, outs=[args[2].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) + + +def ref_matvec( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.mv(var[cache.opr[0]], var[cache.opr[1]]),) + + +def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.matvec(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_mmt4d( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # [m, k, m0, k0] -> [m, m0, k, k0] + _src = var[cache.opr[0]].permute([0, 2, 1, 3]).contiguous() + # [n, k, n0, k0] -> [k, k0, n, n0] + _wei = var[cache.opr[1]].permute([1, 3, 0, 2]).contiguous() + + # [m, m0, k, k0] -> [M, K] + src = _src.reshape([_src.shape[0] * _src.shape[1], _src.shape[2] * _src.shape[3]]) + # [k, k0, n, n0] -> [K, N] + wei = _wei.reshape([_wei.shape[0] * _wei.shape[1], _wei.shape[2] * _wei.shape[3]]) + + dst = torch.mm(src, wei) + # [M, N] -> [m, m0, n, n0] + dst = dst.reshape([_src.shape[0], _src.shape[1], _wei.shape[-2], _wei.shape[-1]]) + + # [m, m0, n, n0] -> [m, n, m0, n0] + return (dst.transpose(1, 2).contiguous(),) + + +def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.mmt4d(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) + + +def ref_vecmat( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), + ) + + +def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.vecmat(arg0, arg1, outs=[args[2].get_zero_op(ctx)]) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/misc.py b/test/benchgc/src/benchgc/linalg/misc.py new file mode 100644 index 000000000..cf672956c --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/misc.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import copy +from typing import Dict, List, Tuple + +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir._mlir_libs._mlir.ir import DenseI64ArrayAttr +from gc_mlir.dialects import linalg +from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType + + +# 1. use to reshape to match ndim +# 2. perform broadcast +def ref_broadcast( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + dst_shape: List[int] = op.results[0].type.shape + tmp_shape = copy.copy(dst_shape) + dimensions: DenseI64ArrayAttr = op.attributes["dimensions"] + for d in dimensions: + tmp_shape[d] = 1 + + return (var[cache.opr[0]].reshape(tmp_shape).broadcast_to(dst_shape).contiguous(),) + + +def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.broadcast( + arg0, outs=[args[1].get_zero_op(ctx)], dimensions=flags.dimensions + ) + ], + ) + + +def ref_fill( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return (torch.full(tuple(op.results[0].type.shape), var[cache.opr[0]]),) + + +def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.fill( + arg0, outs=[args[1].get_zero_op(ctx)], dimensions=flags.dimensions + ) + ], + ) + + +def ref_copy( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + var[cache.opr[0]] + .to(benchgc.util.get_dtype(str(op.result.type.element_type))) + .clone(), + ) + + +def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.copy( + arg0, outs=[args[1].get_zero_op(ctx)], cast=TypeFnType(flags.cast) + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/pool.py b/test/benchgc/src/benchgc/linalg/pool.py new file mode 100644 index 000000000..9779256df --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/pool.py @@ -0,0 +1,489 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_pooling_nchw_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]], + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_pooling_nchw_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nchw_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nchw_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool2d or lp_pool2d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[1] + kernel = var[cache.opr[1]] + return ( + torch.conv2d( + var[cache.opr[0]], + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ), + ) + + +def mlir_pooling_nchw_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nchw_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ncw_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]], + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ), + ) + + +def mlir_pooling_ncw_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ncw_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ncw_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool1d or lp_pool1d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[1] + kernel = var[cache.opr[1]] + return ( + torch.conv1d( + var[cache.opr[0]], + torch.ones(channel, 1, kernel.shape[0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ), + ) + + +def mlir_pooling_ncw_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ncw_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ndhwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool3d( + var[cache.opr[0]].permute([0, -1, 1, 2, 3]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_pooling_ndhwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ndhwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_ndhwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool3d or lp_pool3d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv3d( + var[cache.opr[0]].permute([0, -1, 1, 2, 3]), + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 3, 4, 1]) + .contiguous(), + ) + + +def mlir_pooling_ndhwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_ndhwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]].permute([0, -1, 1, 2]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_pooling_nhwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool2d or lp_pool2d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv2d( + var[cache.opr[0]].permute([0, -1, 1, 2]), + torch.ones(channel, 1, kernel.shape[0], kernel.shape[1]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 3, 1]) + .contiguous(), + ) + + +def mlir_pooling_nhwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nhwc_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool2d( + var[cache.opr[0]].permute([0, -1, 1, 2]).neg(), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 3, 1]) + .neg() + .contiguous(), + ) + + +def mlir_pooling_nhwc_min( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nhwc_min( + arg0, + arg1, + outs=[args[2].get_max_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_max( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]].permute([0, -1, 1]), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_pooling_nwc_max( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_max( + arg0, + arg1, + outs=[args[2].get_min_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_min( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + return ( + torch.max_pool1d( + var[cache.opr[0]].permute([0, -1, 1]).neg(), + kernel_size=var[cache.opr[1]].shape, + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + ) + .permute([0, 2, 1]) + .contiguous() + .neg(), + ) + + +def mlir_pooling_nwc_min( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_min( + arg0, + arg1, + outs=[args[2].get_max_value_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) + + +def ref_pooling_nwc_sum( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + + # pytorch does not support pooling on sum + # avg_pool3d or lp_pool3d with p = 1 does not support dilation + # we will use depthwise convolution with kernel equals to 1 to calculate the sum + + # FIXME: improve the code if pytorch support the sum pooling with dilation + + channel = var[cache.opr[0]].shape[-1] + kernel = var[cache.opr[1]] + return ( + torch.conv1d( + var[cache.opr[0]].permute([0, -1, 1]), + torch.ones(channel, 1, kernel.shape[0]), + stride=[strides[i] for i in range(len(strides))], + dilation=[dilations[i] for i in range(len(dilations))], + groups=channel, + ) + .permute([0, 2, 1]) + .contiguous(), + ) + + +def mlir_pooling_nwc_sum( + flags: argparse.Namespace, args: List[Arg] +) -> gc_mlir.ir.Module: + return init_module( + (args[0], args[1]), + (args[2],), + lambda ctx, arg0, arg1: [ + linalg.pooling_nwc_sum( + arg0, + arg1, + outs=[args[2].get_zero_op(ctx)], + strides=flags.strides, + dilations=flags.dilations, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/linalg/softmax.py b/test/benchgc/src/benchgc/linalg/softmax.py new file mode 100644 index 000000000..20ed39fcb --- /dev/null +++ b/test/benchgc/src/benchgc/linalg/softmax.py @@ -0,0 +1,47 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.module import init_module +from benchgc.mlir.util import MLIRCache +from gc_mlir.dialects import linalg + + +def ref_softmax( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + dimension: gc_mlir.ir.IntegerAttr = op.attributes["dimension"] + return (torch.softmax(var[cache.opr[0]], dimension.value),) + + +def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: + return init_module( + (args[0],), + (args[1],), + lambda ctx, arg0: [ + linalg.softmax( + result=[args[1].get_ranked_tensor_type(ctx)], + input=arg0, + output=args[1].get_zero_op(ctx), + dimension=flags.dimension, + ) + ], + ) diff --git a/test/benchgc/src/benchgc/mlir/CMakeLists.txt b/test/benchgc/src/benchgc/mlir/CMakeLists.txt new file mode 100644 index 000000000..5f8d589b4 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/mlir/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/mlir/__init__.py b/test/benchgc/src/benchgc/mlir/__init__.py new file mode 100644 index 000000000..4d3e897ce --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/__init__.py @@ -0,0 +1,15 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py new file mode 100644 index 000000000..364b9d92c --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -0,0 +1,174 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import ctypes +from typing import Any, List + +import benchgc.util +import gc_mlir.dialects.arith +import gc_mlir.dialects.linalg +import gc_mlir.dialects.tensor +import gc_mlir.ir +import torch +from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr + + +# scalar should give a address +# map torch.Tensor -> memref +# map int address -> scalar value +def get_mlir_args(args: List[torch.Tensor | int]): + mlir_args: List[Any] = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + mlir_args.append(ctypes.pointer(ctypes.pointer(get_md(arg)))) + else: + mlir_args.append(ctypes.c_void_p(arg)) + + return mlir_args + + +def get_md(tensor: torch.Tensor): + if tensor.ndim == 0: + + class _0dMemrefDescriptor(ctypes.Structure): + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype_to_ctype(tensor.dtype))), + ("offset", ctypes.c_longlong), + ] + + md = _0dMemrefDescriptor() + else: + ctype_shape = ctypes.c_longlong * tensor.ndim + ctype_strides = ctypes.c_longlong * tensor.ndim + + class _ndMemrefDescriptor(ctypes.Structure): + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype_to_ctype(tensor.dtype))), + ("offset", ctypes.c_longlong), + ("shape", ctype_shape), + ("strides", ctype_strides), + ] + + md = _ndMemrefDescriptor() + md.shape = ctype_shape(*tensor.shape) + md.strides = ctype_strides(*tensor.stride()) + + md.allocated = tensor.data_ptr() + md.aligned = ctypes.cast( + ctypes.c_void_p(tensor.data_ptr()), ctypes.POINTER(dtype_to_ctype(tensor.dtype)) + ) + md.offset = ctypes.c_longlong(0) + return md + + +class MLIRArg: + dtype: str + shape: List[int] + + scalar: bool + + def __init__(self) -> None: + self.dtype = "" + + # md format: + # 0d memref/tensor: 0xf32 + # nd memref/tensor: 2x3xf32 + # scalar: f32 + def set_md(self, md: str): + splited: List[str] = md.split("x") + self.dtype = splited[-1] + self.shape = [] + + for dim in splited[:-1]: + self.shape.append(int(dim)) + self.set_scalar() + + def set_scalar(self): + # use 0xf32 to represent memref + # use f32 to represent f32 + if self.shape == [0]: + self.shape = [] + self.scalar = False + elif self.shape == []: + self.scalar = True + else: + self.scalar = False + + def nelem(self) -> int: + if self.scalar or self.shape == [] or self.shape[0] == 0: + return 1 + ret: int = 1 + for dim in self.shape: + ret = ret * dim + return ret + + def get_mlir_type(self, ctx: gc_mlir.ir.Context) -> gc_mlir.ir.Type: + if self.scalar: + return str_to_mlir_dtype(ctx, self.dtype) + else: + return gc_mlir.ir.RankedTensorType.get( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + + def get_ranked_tensor_type( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.ir.RankedTensorType: + return gc_mlir.ir.RankedTensorType.get( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + + def get_constant_op( + self, ctx: gc_mlir.ir.Context, cst: Any + ) -> gc_mlir.dialects.tensor.OpView: + zero = gc_mlir.dialects.arith.ConstantOp( + value=str_to_mlir_typed_attr(ctx, self.dtype, cst), + result=str_to_mlir_dtype(ctx, self.dtype), + ) + if self.scalar: + return zero + else: + return gc_mlir.dialects.linalg.fill( + zero, + outs=[ + gc_mlir.dialects.tensor.EmptyOp( + self.shape, str_to_mlir_dtype(ctx, self.dtype) + ) + ], + ) + + def get_zero_op(self, ctx: gc_mlir.ir.Context) -> gc_mlir.dialects.tensor.OpView: + return self.get_constant_op(ctx, 0) + + def get_max_value_op( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.dialects.tensor.OpView: + dtype = benchgc.util.get_dtype(self.dtype) + if dtype.is_floating_point: + return self.get_constant_op(ctx, torch.finfo(dtype).max) + else: + return self.get_constant_op(ctx, torch.iinfo(dtype).max) + + def get_min_value_op( + self, ctx: gc_mlir.ir.Context + ) -> gc_mlir.dialects.tensor.OpView: + dtype = benchgc.util.get_dtype(self.dtype) + if dtype.is_floating_point: + return self.get_constant_op(ctx, torch.finfo(dtype).min) + else: + return self.get_constant_op(ctx, torch.iinfo(dtype).min) diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py new file mode 100644 index 000000000..806c9d8b7 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -0,0 +1,48 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Callable, List, Tuple + +import gc_mlir.dialects.tensor +import gc_mlir.ir +from benchgc.mlir.arg import MLIRArg +from gc_mlir.dialects import func + + +def init_module( + inputs: Tuple[MLIRArg, ...], + outputs: Tuple[MLIRArg, ...], + op_func: Callable[ + [gc_mlir.ir.Context, Tuple[gc_mlir.ir.BlockArgument, ...]], + List[gc_mlir.ir.OpResult], + ], +) -> gc_mlir.ir.Module: + with gc_mlir.ir.Context() as ctx, gc_mlir.ir.Location.unknown(): + module = gc_mlir.ir.Module.create() + with gc_mlir.ir.InsertionPoint(module.body): + f = func.FuncOp( + name="entry", + type=gc_mlir.ir.FunctionType.get( + inputs=[x.get_mlir_type(ctx) for x in inputs], + results=[x.get_mlir_type(ctx) for x in outputs], + ), + ) + f.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() + + with gc_mlir.ir.InsertionPoint(f.add_entry_block()): + block_args = f.entry_block.arguments + func.ReturnOp(op_func(ctx, *block_args)) + return module diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py new file mode 100644 index 000000000..24169bca1 --- /dev/null +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -0,0 +1,111 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import ctypes +from typing import Any, List + +import gc_mlir.ir +import torch +from gc_mlir.dialects import func + +# only python 3.11 support +# from typing import Self + + +def get_entry(module: gc_mlir.ir.Module, entry: str = '"entry"') -> func.FuncOp: + for op in module.operation.opview.regions[0].blocks[0].operations: + if str(op.name) == entry: + return op + raise Exception(f"entry function {entry} is not found at the top level") + + +# calling python binding consumes a lot of time e.g. get_name() +# we need to cache some result to avoid duplicate call +class MLIRCache: + # operand name cache + opr: List[str] + # result name cache + res: List[str] + # argument name cache + arg: List[str] + # next hierarchy + next = [] # List[Self] + + def __init__(self): + self.opr = [] + self.res = [] + self.arg = [] + self.next = [] + + +def dtype_to_ctype(dtype: torch.dtype): + if dtype == torch.float32: + return ctypes.c_float + elif dtype == torch.float64: + return ctypes.c_double + elif dtype == torch.int32: + return ctypes.c_int + elif dtype == torch.int64: + return ctypes.c_longlong + elif dtype == torch.uint8: + return ctypes.c_ubyte + elif dtype == torch.int8: + return ctypes.c_byte + elif dtype == torch.int16 or dtype == torch.bfloat16 or torch.float16: + return ctypes.c_short + elif dtype == torch.bool: + return ctypes.c_bool + else: + raise ValueError(f"Unsupported torch dtype: {dtype}") + + +def str_to_mlir_dtype(ctx: gc_mlir.ir.Context, dtype: str) -> gc_mlir.ir.Type: + if dtype == "f32": + return gc_mlir.ir.F32Type.get(ctx) + elif dtype == "f64": + return gc_mlir.ir.F64Type.get(ctx) + elif dtype == "f16": + return gc_mlir.ir.F16Type.get(ctx) + elif dtype == "bf16": + return gc_mlir.ir.BF16Type.get(ctx) + elif dtype == "u8": + return gc_mlir.ir.IntegerType.get_unsigned(8, ctx) + elif dtype == "s8": + return gc_mlir.ir.IntegerType.get_signed(8, ctx) + elif dtype == "boolean": + return gc_mlir.ir.IntegerType.get_unsigned(1, ctx) + elif dtype == "f8_e4m3": + return gc_mlir.ir.Float8E4M3FNType.get(ctx) + elif dtype == "f8_e5m2": + return gc_mlir.ir.Float8E5M2Type.get(ctx) + elif dtype == "s32": + return gc_mlir.ir.IntegerType.get_signed(32, ctx) + else: + raise Exception(f"data type not support: {dtype}") + + +def str_to_mlir_typed_attr( + ctx: gc_mlir.ir.Context, dtype: str, value: Any +) -> gc_mlir.ir.Attribute: + mlir_dtype = str_to_mlir_dtype(ctx, dtype) + if dtype in ["f32", "f64", "bf16", "f16", "f8_e4m3", "f8_e5m2"]: + return gc_mlir.ir.FloatAttr.get(mlir_dtype, value) + elif dtype in ["u8", "s8", "s32"]: + return gc_mlir.ir.IntegerAttr.get(mlir_dtype, value) + elif dtype == "boolean": + return gc_mlir.ir.BoolAttr.get(value) + else: + raise Exception(f"data type not support: {dtype}") diff --git a/test/benchgc/src/benchgc/runner.py b/test/benchgc/src/benchgc/runner.py new file mode 100644 index 000000000..80178baa8 --- /dev/null +++ b/test/benchgc/src/benchgc/runner.py @@ -0,0 +1,109 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import gc_mlir._mlir_libs +import gc_mlir.dialects +import gc_mlir.dialects.func +import gc_mlir.ir +import torch +from benchgc.arith import ref_op as arith_ref_op +from benchgc.linalg import ref_op as linalg_ref_op +from benchgc.mlir.util import MLIRCache +from benchgc.tensor import ref_op as tensor_ref_op + + +def dfs_op( + cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + + dialect_call: str = str(op.name) + if dialect_call in ["func.return", "linalg.yield"]: + ret: Tuple[torch.Tensor, ...] = tuple() + for name in cache.opr: + ret = ret + (tensors[name],) + return ret + if dialect_call.startswith("linalg"): + ref_op = linalg_ref_op + elif dialect_call.startswith("tensor"): + ref_op = tensor_ref_op + elif dialect_call.startswith("arith"): + ref_op = arith_ref_op + else: + build_cache = len(cache.next) == 0 + for i in range(len(op.regions)): + if build_cache: + # we do not need to cache things for region + # keep an empty cache + cache.next.append(MLIRCache()) + ret = dfs_region(cache.next[i], op.regions[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + dialect_op: str = dialect_call.split(".")[1] + if dialect_op not in ref_op: + raise Exception(f"unknown op call {dialect_call}") + ref_func = ref_op[dialect_op] + for i, res in enumerate(ref_func(cache, op, tensors)): + tensors[cache.res[i]] = res + return tuple() + + +def dfs_region( + cache: MLIRCache, region: gc_mlir.ir.Region, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + build_cache = len(cache.next) == 0 + for i in range(len(region.blocks)): + if build_cache: + _cache = MLIRCache() + # we need to cache argument name for block object + for arg in region.blocks[i].arguments: + _cache.arg.append(arg.get_name()) + cache.next.append(_cache) + ret = dfs_block(cache.next[i], region.blocks[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + +def dfs_block( + cache: MLIRCache, block: gc_mlir.ir.Block, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + build_cache = len(cache.next) == 0 + for i in range(len(block.operations)): + if build_cache: + _cache = MLIRCache() + # we need to cache operand name and result name + for opr in block.operations[i].operands: + _cache.opr.append(opr.get_name()) + + for res in block.operations[i].results: + _cache.res.append(res.get_name()) + cache.next.append(_cache) + + ret = dfs_op(cache.next[i], block.operations[i], tensors) + if len(ret) != 0: + return ret + return tuple() + + +def ref_run( + entry: gc_mlir.dialects.func.FuncOp, tensors: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # cache some information of block & op + return dfs_op(MLIRCache(), entry, tensors) diff --git a/test/benchgc/src/benchgc/tensor/CMakeLists.txt b/test/benchgc/src/benchgc/tensor/CMakeLists.txt new file mode 100644 index 000000000..7b1b990dc --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/tensor/ COPYONLY) +endforeach() diff --git a/test/benchgc/src/benchgc/tensor/__init__.py b/test/benchgc/src/benchgc/tensor/__init__.py new file mode 100644 index 000000000..2f8bc98a4 --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/__init__.py @@ -0,0 +1,45 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import argparse +import importlib +from typing import Callable, Dict, Tuple + +import gc_mlir.ir +import torch +from benchgc.arg import Arg +from benchgc.mlir.util import MLIRCache + +ref_op: Dict[ + str, + Callable[ + [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + Tuple[torch.Tensor, ...], + ], +] = {} +mlir_op: Dict[ + str, Callable[[argparse.Namespace, Dict[str, Arg]], gc_mlir.ir.Module] +] = {} + +for dri in ["basic", "shape"]: + mod = importlib.import_module(f"benchgc.tensor.{dri}") + for key in mod.__dict__: + if key.startswith("ref_"): + op: str = key.removeprefix("ref_") + ref_op[op] = mod.__dict__[key] + if key.startswith("mlir_"): + op: str = key.removeprefix("mlir_") + mlir_op[op] = mod.__dict__[key] diff --git a/test/benchgc/src/benchgc/tensor/basic.py b/test/benchgc/src/benchgc/tensor/basic.py new file mode 100644 index 000000000..eb56aafbc --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/basic.py @@ -0,0 +1,33 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, Tuple + +import benchgc.util +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_empty( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + return ( + torch.zeros( + size=op.results[0].type.shape, + dtype=benchgc.util.get_dtype(str(op.results[0].type.element_type)), + ), + ) diff --git a/test/benchgc/src/benchgc/tensor/shape.py b/test/benchgc/src/benchgc/tensor/shape.py new file mode 100644 index 000000000..18d9fbb2c --- /dev/null +++ b/test/benchgc/src/benchgc/tensor/shape.py @@ -0,0 +1,59 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from typing import Dict, List, Tuple + +import gc_mlir.ir +import torch +from benchgc.mlir.util import MLIRCache + + +def ref_collapse_shape( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # permute axis and do reshape + reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + permutation: List[int] = [] + shape: List[int] = [] + for outdim in reassociation: + d: int = 1 + for indim in outdim: + permutation.append(int(indim)) + d = d * int(op.operands[0].type.shape[int(indim)]) + shape.append(d) + return ( + torch.permute(var[cache.opr[0]], tuple(permutation)) + .contiguous() + .reshape(shape) + .contiguous(), + ) + + +def ref_expand_shape( + cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + # permute axis and do reshape + reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + permutation: List[int] = [0] * len(op.result.type.shape) + shape: List[int] = [] + + d: int = 0 + for indim in reassociation: + for outdim in indim: + shape.append(int(op.result.type.shape[int(outdim)])) + permutation[int(outdim)] = d + d = d + 1 + return (torch.reshape(var[cache.opr[0]], shape).permute(permutation).contiguous(),) diff --git a/test/benchgc/src/benchgc/util.py b/test/benchgc/src/benchgc/util.py new file mode 100644 index 000000000..de275f0fd --- /dev/null +++ b/test/benchgc/src/benchgc/util.py @@ -0,0 +1,334 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import operator +import random +from functools import reduce +from typing import Any, Callable, List, Tuple, Union + +import numpy +import torch + +# verbose level +NO_VERBOSE = 0 +MODULE_VERBOSE = 1 # print the module will be executed +ARG_VERBOSE = 2 # + print arg information +COMPARE_VERBOSE = 3 # + print threshold for comparison +ERROR_OUTPUT_VERBOSE = 4 # + print all error data points if failed +OUTPUT_VERBOSE = 5 # + print all result including passed tensor +INPUT_VERBOSE = 6 # + print input torch tensors + +""" +acc | acc | elems | value_range | worst case +s32 | mul | 10 | 3 | 3^10=2^16, out of 2^30 (max integer) +f16 | mul | 10 | 1 | (2^1)^10=2^10, out of 2^16 (max exponent) +f32 | mul | 30 | 3 | (2^3)^30=2^90, out of 2^128 (max exponent) +s32 | sum | 10000 | 50 | 10000*50=2^19, out of 2^30 (max integer) +f16 | sum | 1000 | 8 | 1000*8=2^13, out of 2^10 (max mantissa/integer) +f32 | sum | 10000 | 16 | 10000*16=2^18, out of 2^23 (max mantissa/integer) + min/max | all | 1000 | no limits on accumulation chain + +In f16 cases, the worst case exceeds the data type bounds, however it's rare +to reach these extreme cases as long as they're close (can't just use f32 bounds) +""" +# first: nonneutral elements +# second: maximum range +_problem_bounds = { + "mul_int": (10, 3), + "mul_fp16": (10, 1), + "mul_fp32": (30, 3), + "sum_int": (10000, 50), + "sum_fp16": (1000, 8), + "sum_fp32": (10000, 16), + "minmax_int": (-1, 1000), + "minmax_fp": (-1, 1000), +} +_dtype_2_range = { + "f32": (-16777216, 16777216), + "f64": (-16777216, 16777216), + "f16": (-2048, 2048), + "bf16": (-16777216, 16777216), + "f8_e5m2": (-2048, 2048), + "f8_e4m3": (-2048, 2048), + "u8": (0, 255), + "s8": (-128, 127), + "s32": (-2147483648, 2147483520), +} + + +def flip_coin( + seed: Union[Any, torch.Tensor], prob: Union[float, torch.Tensor] +) -> Union[bool, torch.Tensor]: + big_prime: int = 1000003 + prime: int = 753737 + seed = seed * prime + return (seed % big_prime) < (prob * big_prime) + + +def get_problem_bounds(kind: str, dt: torch.dtype) -> Tuple[int, int]: + if not dt.is_floating_point: + if kind in ["max", "min"]: + return _problem_bounds["minmax_int"] + elif kind == "mul": + return _problem_bounds["mul_int"] + else: + return _problem_bounds["sum_int"] + elif kind in ["max", "min"]: + return _problem_bounds["minmax_fp"] + elif kind == "mul": + return ( + _problem_bounds["mul_fp16"] + if dt == torch.float16 + else _problem_bounds["mul_fp32"] + ) + else: + return ( + _problem_bounds["sum_fp16"] + if dt == torch.float16 + else _problem_bounds["sum_fp32"] + ) + + +def get_type_range(dt: str) -> Tuple[float, float]: + return _dtype_2_range[dt] + + +# Lnorm, Bnorm & Conv +def get_digits(dtype: str) -> int: + return { + "f32": 24, + "f64": 53, + "s8": 7, + "u8": 8, + "f16": 11, + "bf16": 8, + "f8_e5m2": 3, + "f8_e4m3": 4, + }[dtype] + + +def get_dtype(dtype: str) -> torch.dtype: + if dtype == "f32": + return torch.float32 + elif dtype == "f64": + return torch.float64 + elif dtype == "f16": + return torch.float16 + elif dtype == "bf16": + return torch.bfloat16 + elif dtype == "u8" or dtype == "ui8": + return torch.uint8 + elif dtype == "s8" or dtype == "i8": + return torch.int8 + elif dtype == "boolean": + return torch.uint8 + elif dtype == "f8_e4m3": + return torch.float8_e4m3fn + elif dtype == "f8_e5m2": + return torch.float8_e5m2 + elif dtype == "s32" or dtype == "i32": + return torch.int32 + else: + raise Exception(f"data type not support: {dtype}") + + +def get_eps(dtype: torch.dtype) -> float: + return torch.finfo(dtype).eps if dtype.is_floating_point else 0.0 + + +_seed: int = 0 + + +def set_seed(seed: int): + global _seed + _seed = seed + + +def torch_seed(seed_scale: int = 1, seed_shift: int = 0): + seed: int = _seed * seed_scale + seed_shift + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + numpy.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def iterate_tensor(tensor: torch.Tensor, fn: Callable[[Tuple[int, ...]], None]): + index: List[int] = [0] * tensor.ndim + + def dfs(depth: int): + if depth == tensor.ndim: + fn(tuple(index)) + else: + for i in range(tensor.shape[depth]): + index[depth] = i + dfs(depth + 1) + + +# indicate how to check the result +class Checker: + use_norm: bool + # set if negative result is trancated to zero + truncate_negative: bool + eltwise_relax: bool + threshold: float + zero_percent: float + # args: [ref, res, abs_diff, rel_diff] + customized_checker: ( + Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + | None + ) + + def __init__( + self, + threshold: float, + zero_percent: float, + use_norm: bool = False, + eltwise_relax: bool = False, + truncate_negative: bool = False, + checker: ( + Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor + ] + | None + ) = None, + ) -> None: + self.use_norm = use_norm + self.eltwise_relax = eltwise_relax + self.threshold = threshold + self.zero_percent = zero_percent + self.truncate_negative = truncate_negative + self.customized_checker = checker + + def check( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + if self.use_norm: + return self.norm(ref, res, verbose) + else: + return self.p2p(ref, res, verbose) + + def norm( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + if f32_ref.nelement() == 0: + return (True, None) + + diff_square_sum = torch.square(torch.subtract(f32_ref, f32_res)).sum() + square_sum = torch.square(f32_ref).sum() + + l2_diff_norm = torch.sqrt(diff_square_sum / square_sum).item() + if verbose >= COMPARE_VERBOSE: + print(f"norm check: {l2_diff_norm:.10f} / threshold: {self.threshold:.10f}") + + return (l2_diff_norm < self.threshold, None) + + def p2p( + self, ref: torch.Tensor, res: torch.Tensor, verbose: int + ) -> Tuple[bool, bool | None]: + + if verbose >= COMPARE_VERBOSE: + print(f"p2p check: threshold: {self.threshold:.7f}") + f32_ref = ref.to(torch.float32) + f32_res = res.to(torch.float32) + + check = torch.BoolTensor([False]) + + check = check.bitwise_or(torch.bitwise_and(f32_ref.isnan(), f32_res.isnan())) + check = check.bitwise_or( + torch.bitwise_and(f32_ref.isneginf(), f32_res.isneginf()) + ) + check = check.bitwise_or( + torch.bitwise_and(f32_ref.isposinf(), f32_res.isposinf()) + ) + + # choose diff/rel_diff based on value + abs_diff = (f32_ref - f32_res).abs() + rel_diff = abs_diff / torch.where( + f32_ref.abs() > numpy.finfo(numpy.float32).smallest_subnormal, + f32_ref.abs(), + 1, + ) + # pick a diff for comparison + diff = torch.where(f32_ref.abs() > 1e-5, rel_diff, abs_diff) + + check = check.bitwise_or(diff <= self.threshold) + + if self.eltwise_relax: + check = check.bitwise_or(abs_diff <= max(torch.finfo(res.dtype).eps, 2e-5)) + + if self.customized_checker is not None: + check = check.bitwise_or( + self.customized_checker(ref, res, abs_diff, rel_diff) + ) + + if verbose >= OUTPUT_VERBOSE: + iterate_tensor( + check, + lambda idx: print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + idx, + f32_ref[idx].item(), + f32_res[idx].item(), + abs_diff[idx].item(), + rel_diff[idx].item(), + ) + ), + ) + if check.all(): + # check mistrusted + zero = res.nelement() - res.count_nonzero().item() + if res.nelement() < 10: + mistrust = False + elif self.truncate_negative: + mistrust = ( + zero * 100.0 / res.nelement() > 50.0 + self.zero_percent / 2.0 + ) + else: + mistrust = zero * 100.0 / res.nelement() > self.zero_percent + return (True, mistrust) + else: + if ( + verbose < OUTPUT_VERBOSE + ): # skip verbose print if full output tensor is alrady printed + fail = torch.argwhere(torch.where(check, 0, 1)) + if verbose < ERROR_OUTPUT_VERBOSE: + fail = fail[ + :10 + ] # only print top 10 failed data points if verbose level does not satisfied + for idx in fail: + index: Tuple[int, ...] = tuple(idx.tolist()) + print( + "%20s: ref: %12.7f res: %12.7f abs_diff: %12.7f rel_diff: %12.7f" + % ( + index, + f32_ref[index].item(), + f32_res[index].item(), + abs_diff[index].item(), + rel_diff[index].item(), + ) + ) + return (False, None) + + +def nelem(shape: List[int]) -> int: + return reduce(operator.mul, shape) From 1cabc2c5357bfe1fd3074a97c15967fb55fc6f21 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 18:29:29 -0700 Subject: [PATCH 04/38] remove print --- python/gc_mlir/_mlir_libs/_site_initialize_0.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/gc_mlir/_mlir_libs/_site_initialize_0.py b/python/gc_mlir/_mlir_libs/_site_initialize_0.py index 3fba4fbdd..8cc5ce301 100644 --- a/python/gc_mlir/_mlir_libs/_site_initialize_0.py +++ b/python/gc_mlir/_mlir_libs/_site_initialize_0.py @@ -20,4 +20,4 @@ def context_init_hook(context): register_onednn_graph_dialect(context) except ModuleNotFoundError: - print("onednn_graph dialect not found") + pass From e316a981b5a4876549343d16125289db137f92df Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 19:11:02 -0700 Subject: [PATCH 05/38] fix --- test/benchgc/src/benchgc/__main__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 9beca9409..bb8160191 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -175,9 +175,9 @@ def add_bench_options(parser: argparse.ArgumentParser): parser.add_argument("--entry", type=str, default="main_entry") -def get_pattern_clz(diver_str: str): +def get_pattern_clz(driver_str: str): """Function getting Pattern class by name.""" - clz = {"mlp": MLP}[diver_str] + clz = {"mlp": MLP}[driver_str] return clz @@ -188,7 +188,7 @@ def add_pattern_options(parser: argparse.ArgumentParser): get_pattern_clz(pattern_name).add_args(parser) -def get_moudle_and_args(flags): +def get_module_and_args(flags): args: List[Arg] = [] @@ -355,7 +355,7 @@ def performance_testing(flags, module, args): add_pattern_options(arg_parser) flags = arg_parser.parse_args() benchgc.util.set_seed(flags.seed) - ir_module, module_args = get_moudle_and_args(flags) + ir_module, module_args = get_module_and_args(flags) if flags.mode == "C": correctness_testing(flags, ir_module, module_args) elif flags.mode == "P": From 841e81fa7e8bffd8d17806bad554ad6e2d96177b Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 21:58:46 -0700 Subject: [PATCH 06/38] simplify --- test/benchgc/src/benchgc/__main__.py | 63 ++++--- test/benchgc/src/benchgc/arith/__init__.py | 8 +- test/benchgc/src/benchgc/arith/basic.py | 12 +- test/benchgc/src/benchgc/linalg/__init__.py | 6 +- test/benchgc/src/benchgc/linalg/binary.py | 30 ++-- test/benchgc/src/benchgc/linalg/conv.py | 186 +++++++++----------- test/benchgc/src/benchgc/linalg/eltwise.py | 50 +++--- test/benchgc/src/benchgc/linalg/generic.py | 18 +- test/benchgc/src/benchgc/linalg/matmul.py | 64 +++---- test/benchgc/src/benchgc/linalg/misc.py | 17 +- test/benchgc/src/benchgc/linalg/pool.py | 122 ++++++------- test/benchgc/src/benchgc/linalg/softmax.py | 8 +- test/benchgc/src/benchgc/mlir/arg.py | 28 +-- test/benchgc/src/benchgc/mlir/module.py | 20 +-- test/benchgc/src/benchgc/mlir/util.py | 50 +++--- test/benchgc/src/benchgc/pattern/mlp.py | 20 +-- test/benchgc/src/benchgc/runner.py | 8 +- test/benchgc/src/benchgc/tensor/__init__.py | 8 +- test/benchgc/src/benchgc/tensor/basic.py | 4 +- test/benchgc/src/benchgc/tensor/shape.py | 10 +- 20 files changed, 333 insertions(+), 399 deletions(-) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index bb8160191..95f2f472c 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -20,13 +20,10 @@ import sys from typing import Dict, List -import gc_mlir.ir -import runner -import torch -from gc_mlir.graph_compiler import GraphCompiler - import benchgc.mlir.util import benchgc.util +import runner +import torch from benchgc.arg import ( compare_tensor, fill_tensor, @@ -35,10 +32,14 @@ ) from benchgc.arg.arg import Arg from benchgc.mlir.arg import get_mlir_args -from benchgc.pattern.mlp import MLP from benchgc.mlir.bench import mlir_wrapper_bench, py_timeit_bench +from benchgc.pattern.mlp import MLP +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler + def add_common_options(parser: argparse.ArgumentParser): + """common options for benchgc""" parser.add_argument( "--mode", required=False, @@ -184,35 +185,39 @@ def get_pattern_clz(driver_str: str): def add_pattern_options(parser: argparse.ArgumentParser): '''add options for each pattern''' if parser.parse_known_args()[0].driver == "pattern": - pattern_name = parser.parse_known_args()[0].driver + pattern_name = parser.parse_known_args()[0].case get_pattern_clz(pattern_name).add_args(parser) def get_module_and_args(flags): - args: List[Arg] = [] - - if flags.driver == "mlir": + if flags.driver in ["mlir", "pattern"]: # we need to find all args by reading the entry function - with open(flags.case, "r") as mlir_file: - with gc_mlir.ir.Context() as ctx: - module = gc_mlir.ir.Module.parse(mlir_file.read()) - entry = benchgc.mlir.util.get_entry(module) - idx: int = 0 - # FIXME: only support RankTensorType now - for i in entry.type.inputs: - args.append(Arg(idx)) - args[-1].dtype = str(i.element_type) - args[-1].shape = list(i.shape) - args[-1].set_scalar() - idx += 1 - - for o in entry.type.results: - args.append(Arg(idx)) - args[-1].dtype = str(o.element_type) - args[-1].shape = list(o.shape) - args[-1].set_scalar() - idx += 1 + with ir.Context() as ctx: + if flags.driver == "mlir": + with open(flags.case, "r") as mlir_file: + module = ir.Module.parse(mlir_file.read()) + elif flags.driver == "pattern": + pattern_clz = get_pattern_clz(flags.case) + module = pattern_clz(ctx, flags).ir_module + + entry = benchgc.mlir.util.get_entry(module) + idx: int = 0 + # FIXME: only support RankTensorType now + for i in entry.type.inputs: + args.append(Arg(idx)) + args[-1].dtype = str(i.element_type) + args[-1].shape = list(i.shape) + args[-1].set_scalar() + idx += 1 + + for o in entry.type.results: + args.append(Arg(idx)) + args[-1].dtype = str(o.element_type) + args[-1].shape = list(o.shape) + args[-1].set_scalar() + idx += 1 + elif flags.driver in ["linalg"]: # all arg shape/dt should be provided in single op test for i in range(len(flags.md)): diff --git a/test/benchgc/src/benchgc/arith/__init__.py b/test/benchgc/src/benchgc/arith/__init__.py index a5f942a72..42d6dd0aa 100644 --- a/test/benchgc/src/benchgc/arith/__init__.py +++ b/test/benchgc/src/benchgc/arith/__init__.py @@ -18,21 +18,19 @@ import importlib from typing import Callable, Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[ - str, Callable[[argparse.Namespace, List[Arg], List[Arg]], gc_mlir.ir.Module] -] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg], List[Arg]], ir.Module]] = {} for dri in ["basic"]: mod = importlib.import_module(f"benchgc.arith.{dri}") diff --git a/test/benchgc/src/benchgc/arith/basic.py b/test/benchgc/src/benchgc/arith/basic.py index 7e4b17467..2da0aa022 100644 --- a/test/benchgc/src/benchgc/arith/basic.py +++ b/test/benchgc/src/benchgc/arith/basic.py @@ -18,20 +18,20 @@ import benchgc.util import gc_mlir._mlir_libs._mlir.ir -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_constant( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: value = op.attributes["value"] - if isinstance(value, gc_mlir._mlir_libs._mlir.ir.FloatAttr): + if isinstance(value, ir.FloatAttr): return ( torch.full(size=tuple(), fill_value=value.__float__(), dtype=torch.float), ) - elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.DenseFPElementsAttr): + elif isinstance(value, ir.DenseFPElementsAttr): if value.is_splat: return ( torch.full( @@ -47,12 +47,12 @@ def ref_constant( def ref_mulf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (var[cache.opr[0]] * var[cache.opr[1]],) def ref_addf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (var[cache.opr[0]] + var[cache.opr[1]],) diff --git a/test/benchgc/src/benchgc/linalg/__init__.py b/test/benchgc/src/benchgc/linalg/__init__.py index 331bd75dd..e75068c9a 100644 --- a/test/benchgc/src/benchgc/linalg/__init__.py +++ b/test/benchgc/src/benchgc/linalg/__init__.py @@ -18,19 +18,19 @@ import importlib from typing import Callable, Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], gc_mlir.ir.Module]] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], ir.Module]] = {} for dri in [ "binary", diff --git a/test/benchgc/src/benchgc/linalg/binary.py b/test/benchgc/src/benchgc/linalg/binary.py index ed5d280a3..b8508d8a0 100644 --- a/test/benchgc/src/benchgc/linalg/binary.py +++ b/test/benchgc/src/benchgc/linalg/binary.py @@ -17,21 +17,21 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_add( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.add(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -42,12 +42,12 @@ def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_powf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.pow(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -58,12 +58,12 @@ def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_div( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.div(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -74,12 +74,12 @@ def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -90,12 +90,12 @@ def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -106,12 +106,12 @@ def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_mul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mul(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -122,12 +122,12 @@ def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_sub( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.sub(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), diff --git a/test/benchgc/src/benchgc/linalg/conv.py b/test/benchgc/src/benchgc/linalg/conv.py index c8fc38efb..359d0d586 100644 --- a/test/benchgc/src/benchgc/linalg/conv.py +++ b/test/benchgc/src/benchgc/linalg/conv.py @@ -17,19 +17,19 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_conv_1d_ncw_fcw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv1d( var[cache.opr[0]], @@ -40,9 +40,7 @@ def ref_conv_1d_ncw_fcw( ) -def mlir_conv_1d_ncw_fcw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -59,10 +57,10 @@ def mlir_conv_1d_ncw_fcw( def ref_conv_1d_nwc_wcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # src: nwc -> ncw # wei: wcf -> fcw @@ -80,9 +78,7 @@ def ref_conv_1d_nwc_wcf( ) -def mlir_conv_1d_nwc_wcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_nwc_wcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -99,10 +95,10 @@ def mlir_conv_1d_nwc_wcf( def ref_conv_1d_ncw_fcw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv1d( var[cache.opr[0]], @@ -113,9 +109,7 @@ def ref_conv_1d_ncw_fcw( ) -def mlir_conv_1d_ncw_fcw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -132,7 +126,7 @@ def mlir_conv_1d_ncw_fcw( def ref_conv_1d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv1d( @@ -144,7 +138,7 @@ def ref_conv_1d( ) -def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -159,10 +153,10 @@ def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_conv_2d_nchw_fchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]], @@ -173,9 +167,7 @@ def ref_conv_2d_nchw_fchw( ) -def mlir_conv_2d_nchw_fchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nchw_fchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -192,10 +184,10 @@ def mlir_conv_2d_nchw_fchw( def ref_conv_2d_ngchw_fgchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] wei = var[cache.opr[1]] @@ -221,9 +213,7 @@ def ref_conv_2d_ngchw_fgchw( ) # split group axis from output channel -def mlir_conv_2d_ngchw_fgchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_ngchw_fgchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -240,10 +230,10 @@ def mlir_conv_2d_ngchw_fgchw( def ref_conv_2d_ngchw_gfchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] wei = var[cache.opr[1]] @@ -267,9 +257,7 @@ def ref_conv_2d_ngchw_gfchw( ) # split group axis from output channel -def mlir_conv_2d_ngchw_gfchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_ngchw_gfchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -286,10 +274,10 @@ def mlir_conv_2d_ngchw_gfchw( def ref_conv_2d_nhwc_fhwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]].permute([0, 3, 1, 2]), @@ -302,9 +290,7 @@ def ref_conv_2d_nhwc_fhwc( ) -def mlir_conv_2d_nhwc_fhwc( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nhwc_fhwc(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -321,10 +307,10 @@ def mlir_conv_2d_nhwc_fhwc( def ref_conv_2d_nhwc_hwcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]].permute([0, 3, 1, 2]), @@ -337,9 +323,7 @@ def ref_conv_2d_nhwc_hwcf( ) -def mlir_conv_2d_nhwc_hwcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nhwc_hwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -356,7 +340,7 @@ def mlir_conv_2d_nhwc_hwcf( def ref_conv_2d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv2d( @@ -368,7 +352,7 @@ def ref_conv_2d( ) -def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -383,10 +367,10 @@ def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_conv_3d_ncdhw_fcdhw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv3d( var[cache.opr[0]], @@ -397,9 +381,7 @@ def ref_conv_3d_ncdhw_fcdhw( ) -def mlir_conv_3d_ncdhw_fcdhw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_3d_ncdhw_fcdhw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -416,10 +398,10 @@ def mlir_conv_3d_ncdhw_fcdhw( def ref_conv_3d_ndhwc_dhwcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv3d( var[cache.opr[0]].permute([0, 4, 1, 2, 3]), @@ -432,9 +414,7 @@ def ref_conv_3d_ndhwc_dhwcf( ) -def mlir_conv_3d_ndhwc_dhwcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_3d_ndhwc_dhwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -451,7 +431,7 @@ def mlir_conv_3d_ndhwc_dhwcf( def ref_conv_3d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv3d( @@ -463,7 +443,7 @@ def ref_conv_3d( ) -def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -478,10 +458,10 @@ def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_depthwise_conv_1d_ncw_cw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv1d( @@ -496,7 +476,7 @@ def ref_depthwise_conv_1d_ncw_cw( def mlir_depthwise_conv_1d_ncw_cw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -513,10 +493,10 @@ def mlir_depthwise_conv_1d_ncw_cw( def ref_depthwise_conv_1d_nwc_wc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv1d( @@ -533,7 +513,7 @@ def ref_depthwise_conv_1d_nwc_wc( def mlir_depthwise_conv_1d_nwc_wc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -550,10 +530,10 @@ def mlir_depthwise_conv_1d_nwc_wc( def ref_depthwise_conv_1d_nwc_wcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] groups: int = src.shape[-1] wei = var[cache.opr[1]] @@ -575,7 +555,7 @@ def ref_depthwise_conv_1d_nwc_wcm( def mlir_depthwise_conv_1d_nwc_wcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -592,10 +572,10 @@ def mlir_depthwise_conv_1d_nwc_wcm( def ref_depthwise_conv_2d_nchw_chw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv2d( @@ -610,7 +590,7 @@ def ref_depthwise_conv_2d_nchw_chw( def mlir_depthwise_conv_2d_nchw_chw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -627,10 +607,10 @@ def mlir_depthwise_conv_2d_nchw_chw( def ref_depthwise_conv_2d_nhwc_hwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv2d( @@ -647,7 +627,7 @@ def ref_depthwise_conv_2d_nhwc_hwc( def mlir_depthwise_conv_2d_nhwc_hwc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -664,10 +644,10 @@ def mlir_depthwise_conv_2d_nhwc_hwc( def ref_depthwise_conv_2d_nhwc_hwcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] wei = var[cache.opr[1]] dst = ( @@ -692,7 +672,7 @@ def ref_depthwise_conv_2d_nhwc_hwcm( def mlir_depthwise_conv_2d_nhwc_hwcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -709,10 +689,10 @@ def mlir_depthwise_conv_2d_nhwc_hwcm( def ref_depthwise_conv_3d_ncdhw_cdhw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv3d( @@ -727,7 +707,7 @@ def ref_depthwise_conv_3d_ncdhw_cdhw( def mlir_depthwise_conv_3d_ncdhw_cdhw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -744,10 +724,10 @@ def mlir_depthwise_conv_3d_ncdhw_cdhw( def ref_depthwise_conv_3d_ndhwc_dhwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv3d( @@ -764,7 +744,7 @@ def ref_depthwise_conv_3d_ndhwc_dhwc( def mlir_depthwise_conv_3d_ndhwc_dhwc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -781,10 +761,10 @@ def mlir_depthwise_conv_3d_ndhwc_dhwc( def ref_depthwise_conv_3d_ndhwc_dhwcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] wei = var[cache.opr[1]] dst = ( @@ -818,7 +798,7 @@ def ref_depthwise_conv_3d_ndhwc_dhwcm( def mlir_depthwise_conv_3d_ndhwc_dhwcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), diff --git a/test/benchgc/src/benchgc/linalg/eltwise.py b/test/benchgc/src/benchgc/linalg/eltwise.py index 7ae9b31b7..d6e5a9c5d 100644 --- a/test/benchgc/src/benchgc/linalg/eltwise.py +++ b/test/benchgc/src/benchgc/linalg/eltwise.py @@ -17,21 +17,21 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_abs( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.abs(var[cache.opr[0]]),) -def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -40,12 +40,12 @@ def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_ceil( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.ceil(var[cache.opr[0]]),) -def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -54,12 +54,12 @@ def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_floor( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.floor(var[cache.opr[0]]),) -def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -68,12 +68,12 @@ def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_erf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.erf(var[cache.opr[0]]),) -def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -81,7 +81,7 @@ def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: ) -def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -90,12 +90,12 @@ def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_log( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.log(var[cache.opr[0]]),) -def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -104,18 +104,18 @@ def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_negf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.neg(var[cache.opr[0]]),) def ref_exp( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.exp(var[cache.opr[0]]),) -def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -124,7 +124,7 @@ def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_round( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # torch.round is following the priciple "round half to even" # we need another implementation @@ -133,7 +133,7 @@ def ref_round( return (v + torch.where(var[cache.opr[0]] - v >= 0.5, 1, 0),) -def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -142,12 +142,12 @@ def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_rsqrt( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.rsqrt(var[cache.opr[0]]),) -def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -156,12 +156,12 @@ def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_sqrt( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.sqrt(var[cache.opr[0]]),) -def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -170,12 +170,12 @@ def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_square( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.square(var[cache.opr[0]]),) -def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -184,12 +184,12 @@ def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_tanh( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.tanh(var[cache.opr[0]]),) -def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), diff --git a/test/benchgc/src/benchgc/linalg/generic.py b/test/benchgc/src/benchgc/linalg/generic.py index 67228ab47..6cfded39d 100644 --- a/test/benchgc/src/benchgc/linalg/generic.py +++ b/test/benchgc/src/benchgc/linalg/generic.py @@ -18,14 +18,14 @@ import benchgc.runner import benchgc.util -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def generic_loop( cache: MLIRCache, - op: gc_mlir.ir.OpView, + op: ir.OpView, depth: int, iterspace: Dict[str, Tuple[int, int, int]], affine_from: List[str], @@ -42,7 +42,7 @@ def generic_loop( # region cache cache.next.append(MLIRCache()) - block: gc_mlir.ir.Block = op.regions[0].blocks[0] + block: ir.Block = op.regions[0].blocks[0] if len(cache.next[0].next) == 0: # region->block cache cache.next[0].next.append(MLIRCache()) @@ -96,7 +96,7 @@ def generic_loop( def ref_generic( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: affine_from: List[str] = [] affine_to: List[List[str]] = [] @@ -110,7 +110,7 @@ def ref_generic( # TODO: support affine expression iterspace: Dict[str, Tuple[int, int, int]] = {} - operands: List[gc_mlir.ir.OpOperand] = list(op.operands) + operands: List[ir.OpOperand] = list(op.operands) loop_var: Dict[str, torch.Tensor] = {} for d in affine_from: @@ -142,7 +142,7 @@ def ref_generic( def reduce_loop( cache: MLIRCache, - op: gc_mlir.ir.OpView, + op: ir.OpView, depth: int, in_shape: List[int], var: Dict[str, torch.Tensor], @@ -155,7 +155,7 @@ def reduce_loop( # we need to execute the block here # we will need to read the block argument name and save it into the cache - block: gc_mlir.ir.Block = op.regions[0].blocks[0] + block: ir.Block = op.regions[0].blocks[0] if len(cache.next) == 0: # region cache @@ -180,7 +180,7 @@ def reduce_loop( # perform the yield operation result_tensor[tuple(out_idx)] = res[0] else: - dimensions: gc_mlir.ir.DenseI64ArrayAttr = op.attributes["dimensions"] + dimensions: ir.DenseI64ArrayAttr = op.attributes["dimensions"] reduce_axis: bool = depth in list(dimensions) for i in range(in_shape[depth]): @@ -214,7 +214,7 @@ def reduce_loop( def ref_reduce( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # create the buffer for result tensors tensors[cache.res[0]] = tensors[cache.opr[-1]].clone() diff --git a/test/benchgc/src/benchgc/linalg/matmul.py b/test/benchgc/src/benchgc/linalg/matmul.py index 9efde9612..1c36d9dbd 100644 --- a/test/benchgc/src/benchgc/linalg/matmul.py +++ b/test/benchgc/src/benchgc/linalg/matmul.py @@ -17,22 +17,22 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType def ref_batch_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.matmul(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -43,14 +43,14 @@ def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_batch_matmul_transpose_a( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.bmm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) def mlir_batch_matmul_transpose_a( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -61,14 +61,14 @@ def mlir_batch_matmul_transpose_a( def ref_batch_matmul_transpose_b( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.bmm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) def mlir_batch_matmul_transpose_b( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -79,7 +79,7 @@ def mlir_batch_matmul_transpose_b( def ref_batch_matvec( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # pytorch does not support bmv return ( @@ -87,7 +87,7 @@ def ref_batch_matvec( ) -def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -98,7 +98,7 @@ def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_batch_mmt4d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # [B, m, k, m0, k0] -> [B, m, m0, k, k0] _src = var[cache.opr[0]].permute([0, 1, 3, 2, 4]).contiguous() @@ -124,7 +124,7 @@ def ref_batch_mmt4d( return (dst.transpose(2, 3).contiguous(),) -def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -135,7 +135,7 @@ def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.M def ref_batch_reduce_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.addbmm( @@ -148,9 +148,7 @@ def ref_batch_reduce_matmul( ) -def mlir_batch_reduce_matmul( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_batch_reduce_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -161,14 +159,14 @@ def mlir_batch_reduce_matmul( def ref_batch_vecmat( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), ) -def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -179,12 +177,12 @@ def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_dot( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.dot(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -195,12 +193,12 @@ def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -213,14 +211,12 @@ def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_matmul_transpose_a( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) -def mlir_matmul_transpose_a( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_matmul_transpose_a(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -233,14 +229,12 @@ def mlir_matmul_transpose_a( def ref_matmul_transpose_b( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) -def mlir_matmul_transpose_b( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_matmul_transpose_b(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -253,12 +247,12 @@ def mlir_matmul_transpose_b( def ref_matvec( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mv(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -269,7 +263,7 @@ def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_mmt4d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # [m, k, m0, k0] -> [m, m0, k, k0] _src = var[cache.opr[0]].permute([0, 2, 1, 3]).contiguous() @@ -289,7 +283,7 @@ def ref_mmt4d( return (dst.transpose(1, 2).contiguous(),) -def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -300,14 +294,14 @@ def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_vecmat( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), ) -def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), diff --git a/test/benchgc/src/benchgc/linalg/misc.py b/test/benchgc/src/benchgc/linalg/misc.py index cf672956c..6020e54db 100644 --- a/test/benchgc/src/benchgc/linalg/misc.py +++ b/test/benchgc/src/benchgc/linalg/misc.py @@ -19,12 +19,11 @@ from typing import Dict, List, Tuple import benchgc.util -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache -from gc_mlir._mlir_libs._mlir.ir import DenseI64ArrayAttr +from gc_mlir import ir from gc_mlir.dialects import linalg from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType @@ -32,18 +31,18 @@ # 1. use to reshape to match ndim # 2. perform broadcast def ref_broadcast( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: dst_shape: List[int] = op.results[0].type.shape tmp_shape = copy.copy(dst_shape) - dimensions: DenseI64ArrayAttr = op.attributes["dimensions"] + dimensions: ir.DenseI64ArrayAttr = op.attributes["dimensions"] for d in dimensions: tmp_shape[d] = 1 return (var[cache.opr[0]].reshape(tmp_shape).broadcast_to(dst_shape).contiguous(),) -def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), @@ -57,12 +56,12 @@ def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Mod def ref_fill( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.full(tuple(op.results[0].type.shape), var[cache.opr[0]]),) -def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), @@ -75,7 +74,7 @@ def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_copy( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( var[cache.opr[0]] @@ -84,7 +83,7 @@ def ref_copy( ) -def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), diff --git a/test/benchgc/src/benchgc/linalg/pool.py b/test/benchgc/src/benchgc/linalg/pool.py index 9779256df..1b2f03412 100644 --- a/test/benchgc/src/benchgc/linalg/pool.py +++ b/test/benchgc/src/benchgc/linalg/pool.py @@ -17,19 +17,19 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_pooling_nchw_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]], @@ -40,9 +40,7 @@ def ref_pooling_nchw_max( ) -def mlir_pooling_nchw_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nchw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -59,10 +57,10 @@ def mlir_pooling_nchw_max( def ref_pooling_nchw_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool2d or lp_pool2d with p = 1 does not support dilation @@ -83,9 +81,7 @@ def ref_pooling_nchw_sum( ) -def mlir_pooling_nchw_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nchw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -102,10 +98,10 @@ def mlir_pooling_nchw_sum( def ref_pooling_ncw_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]], @@ -116,9 +112,7 @@ def ref_pooling_ncw_max( ) -def mlir_pooling_ncw_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ncw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -135,10 +129,10 @@ def mlir_pooling_ncw_max( def ref_pooling_ncw_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool1d or lp_pool1d with p = 1 does not support dilation @@ -159,9 +153,7 @@ def ref_pooling_ncw_sum( ) -def mlir_pooling_ncw_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ncw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -178,10 +170,10 @@ def mlir_pooling_ncw_sum( def ref_pooling_ndhwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool3d( var[cache.opr[0]].permute([0, -1, 1, 2, 3]), @@ -194,9 +186,7 @@ def ref_pooling_ndhwc_max( ) -def mlir_pooling_ndhwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ndhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -213,10 +203,10 @@ def mlir_pooling_ndhwc_max( def ref_pooling_ndhwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool3d or lp_pool3d with p = 1 does not support dilation @@ -239,9 +229,7 @@ def ref_pooling_ndhwc_sum( ) -def mlir_pooling_ndhwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ndhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -258,10 +246,10 @@ def mlir_pooling_ndhwc_sum( def ref_pooling_nhwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]].permute([0, -1, 1, 2]), @@ -274,9 +262,7 @@ def ref_pooling_nhwc_max( ) -def mlir_pooling_nhwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -293,10 +279,10 @@ def mlir_pooling_nhwc_max( def ref_pooling_nhwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool2d or lp_pool2d with p = 1 does not support dilation @@ -319,9 +305,7 @@ def ref_pooling_nhwc_sum( ) -def mlir_pooling_nhwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -338,10 +322,10 @@ def mlir_pooling_nhwc_sum( def ref_pooling_nhwc_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]].permute([0, -1, 1, 2]).neg(), @@ -355,9 +339,7 @@ def ref_pooling_nhwc_min( ) -def mlir_pooling_nhwc_min( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -374,10 +356,10 @@ def mlir_pooling_nhwc_min( def ref_pooling_nwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]].permute([0, -1, 1]), @@ -390,9 +372,7 @@ def ref_pooling_nwc_max( ) -def mlir_pooling_nwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -409,10 +389,10 @@ def mlir_pooling_nwc_max( def ref_pooling_nwc_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]].permute([0, -1, 1]).neg(), @@ -426,9 +406,7 @@ def ref_pooling_nwc_min( ) -def mlir_pooling_nwc_min( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), @@ -445,10 +423,10 @@ def mlir_pooling_nwc_min( def ref_pooling_nwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool3d or lp_pool3d with p = 1 does not support dilation @@ -471,9 +449,7 @@ def ref_pooling_nwc_sum( ) -def mlir_pooling_nwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0], args[1]), (args[2],), diff --git a/test/benchgc/src/benchgc/linalg/softmax.py b/test/benchgc/src/benchgc/linalg/softmax.py index 20ed39fcb..fe18a0017 100644 --- a/test/benchgc/src/benchgc/linalg/softmax.py +++ b/test/benchgc/src/benchgc/linalg/softmax.py @@ -17,22 +17,22 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_softmax( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - dimension: gc_mlir.ir.IntegerAttr = op.attributes["dimension"] + dimension: ir.IntegerAttr = op.attributes["dimension"] return (torch.softmax(var[cache.opr[0]], dimension.value),) -def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( (args[0],), (args[1],), diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py index 364b9d92c..e48a1aed9 100644 --- a/test/benchgc/src/benchgc/mlir/arg.py +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -21,9 +21,9 @@ import gc_mlir.dialects.arith import gc_mlir.dialects.linalg import gc_mlir.dialects.tensor -import gc_mlir.ir import torch from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr +from gc_mlir import ir # scalar should give a address @@ -118,24 +118,18 @@ def nelem(self) -> int: ret = ret * dim return ret - def get_mlir_type(self, ctx: gc_mlir.ir.Context) -> gc_mlir.ir.Type: + def get_mlir_type(self, ctx: ir.Context) -> ir.Type: if self.scalar: return str_to_mlir_dtype(ctx, self.dtype) else: - return gc_mlir.ir.RankedTensorType.get( + return ir.RankedTensorType.get( self.shape, str_to_mlir_dtype(ctx, self.dtype) ) - def get_ranked_tensor_type( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.ir.RankedTensorType: - return gc_mlir.ir.RankedTensorType.get( - self.shape, str_to_mlir_dtype(ctx, self.dtype) - ) + def get_ranked_tensor_type(self, ctx: ir.Context) -> ir.RankedTensorType: + return ir.RankedTensorType.get(self.shape, str_to_mlir_dtype(ctx, self.dtype)) - def get_constant_op( - self, ctx: gc_mlir.ir.Context, cst: Any - ) -> gc_mlir.dialects.tensor.OpView: + def get_constant_op(self, ctx: ir.Context, cst: Any) -> ir.OpView: zero = gc_mlir.dialects.arith.ConstantOp( value=str_to_mlir_typed_attr(ctx, self.dtype, cst), result=str_to_mlir_dtype(ctx, self.dtype), @@ -152,21 +146,17 @@ def get_constant_op( ], ) - def get_zero_op(self, ctx: gc_mlir.ir.Context) -> gc_mlir.dialects.tensor.OpView: + def get_zero_op(self, ctx: ir.Context) -> ir.OpView: return self.get_constant_op(ctx, 0) - def get_max_value_op( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.dialects.tensor.OpView: + def get_max_value_op(self, ctx: ir.Context) -> ir.OpView: dtype = benchgc.util.get_dtype(self.dtype) if dtype.is_floating_point: return self.get_constant_op(ctx, torch.finfo(dtype).max) else: return self.get_constant_op(ctx, torch.iinfo(dtype).max) - def get_min_value_op( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.dialects.tensor.OpView: + def get_min_value_op(self, ctx: ir.Context) -> ir.OpView: dtype = benchgc.util.get_dtype(self.dtype) if dtype.is_floating_point: return self.get_constant_op(ctx, torch.finfo(dtype).min) diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py index 806c9d8b7..4dccb2928 100644 --- a/test/benchgc/src/benchgc/mlir/module.py +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -17,8 +17,8 @@ from typing import Callable, List, Tuple import gc_mlir.dialects.tensor -import gc_mlir.ir from benchgc.mlir.arg import MLIRArg +from gc_mlir import ir from gc_mlir.dialects import func @@ -26,23 +26,23 @@ def init_module( inputs: Tuple[MLIRArg, ...], outputs: Tuple[MLIRArg, ...], op_func: Callable[ - [gc_mlir.ir.Context, Tuple[gc_mlir.ir.BlockArgument, ...]], - List[gc_mlir.ir.OpResult], + [ir.Context, Tuple[ir.BlockArgument, ...]], + List[ir.OpResult], ], -) -> gc_mlir.ir.Module: - with gc_mlir.ir.Context() as ctx, gc_mlir.ir.Location.unknown(): - module = gc_mlir.ir.Module.create() - with gc_mlir.ir.InsertionPoint(module.body): +) -> ir.Module: + with ir.Context() as ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): f = func.FuncOp( name="entry", - type=gc_mlir.ir.FunctionType.get( + type=ir.FunctionType.get( inputs=[x.get_mlir_type(ctx) for x in inputs], results=[x.get_mlir_type(ctx) for x in outputs], ), ) - f.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with gc_mlir.ir.InsertionPoint(f.add_entry_block()): + with ir.InsertionPoint(f.add_entry_block()): block_args = f.entry_block.arguments func.ReturnOp(op_func(ctx, *block_args)) return module diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 846a81e19..ff8146f20 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -17,15 +17,15 @@ import ctypes from typing import Any, List -import gc_mlir.ir import torch -from gc_mlir.dialects import func, arith, memref +from gc_mlir import ir +from gc_mlir.dialects import arith, func, memref # only python 3.11 support # from typing import Self -def get_entry(module: gc_mlir.ir.Module, entry: str = '"entry"') -> func.FuncOp: +def get_entry(module: ir.Module, entry: str = '"entry"') -> func.FuncOp: for op in module.operation.opview.regions[0].blocks[0].operations: if str(op.name) == entry: return op @@ -72,41 +72,39 @@ def dtype_to_ctype(dtype: torch.dtype): raise ValueError(f"Unsupported torch dtype: {dtype}") -def str_to_mlir_dtype(ctx: gc_mlir.ir.Context, dtype: str) -> gc_mlir.ir.Type: +def str_to_mlir_dtype(ctx: ir.Context, dtype: str) -> ir.Type: if dtype == "f32": - return gc_mlir.ir.F32Type.get(ctx) + return ir.F32Type.get(ctx) elif dtype == "f64": - return gc_mlir.ir.F64Type.get(ctx) + return ir.F64Type.get(ctx) elif dtype == "f16": - return gc_mlir.ir.F16Type.get(ctx) + return ir.F16Type.get(ctx) elif dtype == "bf16": - return gc_mlir.ir.BF16Type.get(ctx) + return ir.BF16Type.get(ctx) elif dtype == "u8": - return gc_mlir.ir.IntegerType.get_unsigned(8, ctx) + return ir.IntegerType.get_unsigned(8, ctx) elif dtype == "s8": - return gc_mlir.ir.IntegerType.get_signed(8, ctx) + return ir.IntegerType.get_signed(8, ctx) elif dtype == "boolean": - return gc_mlir.ir.IntegerType.get_unsigned(1, ctx) + return ir.IntegerType.get_unsigned(1, ctx) elif dtype == "f8_e4m3": - return gc_mlir.ir.Float8E4M3FNType.get(ctx) + return ir.Float8E4M3FNType.get(ctx) elif dtype == "f8_e5m2": - return gc_mlir.ir.Float8E5M2Type.get(ctx) + return ir.Float8E5M2Type.get(ctx) elif dtype == "s32": - return gc_mlir.ir.IntegerType.get_signed(32, ctx) + return ir.IntegerType.get_signed(32, ctx) else: raise Exception(f"data type not support: {dtype}") -def str_to_mlir_typed_attr( - ctx: gc_mlir.ir.Context, dtype: str, value: Any -) -> gc_mlir.ir.Attribute: +def str_to_mlir_typed_attr(ctx: ir.Context, dtype: str, value: Any) -> ir.Attribute: mlir_dtype = str_to_mlir_dtype(ctx, dtype) if dtype in ["f32", "f64", "bf16", "f16", "f8_e4m3", "f8_e5m2"]: - return gc_mlir.ir.FloatAttr.get(mlir_dtype, value) + return ir.FloatAttr.get(mlir_dtype, value) elif dtype in ["u8", "s8", "s32"]: - return gc_mlir.ir.IntegerAttr.get(mlir_dtype, value) + return ir.IntegerAttr.get(mlir_dtype, value) elif dtype == "boolean": - return gc_mlir.ir.BoolAttr.get(value) + return ir.BoolAttr.get(value) else: raise Exception(f"data type not support: {dtype}") @@ -116,9 +114,9 @@ def str_to_mlir_typed_attr( def emit_nano_time() -> func.FuncOp: """Emit a nanoTime function that returns the current time in nanoseconds.""" nanoTime = func.FuncOp( - "nanoTime", ([], [gc_mlir.ir.IntegerType.get_signless(64)]), visibility="private" + "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" ) - nanoTime.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() + nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() return nanoTime @@ -126,7 +124,7 @@ def emit_benchmark_wrapped_main_func( kernel_func: func.FuncOp, timer_func: func.FuncOp ) -> func.FuncOp: """Emit a wrapped main function that calls the kernel function and records the time taken.""" - memref_of_i64_type = gc_mlir.ir.MemRefType.get([1], gc_mlir.ir.IntegerType.get_signless(64)) + memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) wrapped_func_name = "wrapped_main" assert wrapped_func_name != str( kernel_func.name @@ -136,8 +134,8 @@ def emit_benchmark_wrapped_main_func( ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), visibility="public", ) - wrapped_func.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() - with gc_mlir.ir.InsertionPoint(wrapped_func.add_entry_block()): + wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(wrapped_func.add_entry_block()): timer_buffer = wrapped_func.arguments[0] start = func.CallOp(timer_func, []) call_op = func.CallOp( @@ -153,7 +151,7 @@ def emit_benchmark_wrapped_main_func( def get_kernel_func_from_module( - module: gc_mlir.ir.Module, func_name: str = "main_entry" + module: ir.Module, func_name: str = "main_entry" ) -> func.FuncOp: """Get the func op by the name from a module""" assert ( diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index f27838f7a..9182594d2 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -19,13 +19,9 @@ from abc import ABC, abstractmethod from typing import List -import numpy as np +from benchgc.mlir.util import str_to_mlir_dtype from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor -from gc_mlir.ir import BF16Type, FloatAttr -from benchgc.mlir.util import ( - str_to_mlir_dtype, -) def to_int_list(s: str) -> List[int]: @@ -70,7 +66,7 @@ def handle_args(self, args: argparse.Namespace): """Get and handle the args""" def __init__(self, ctx: ir.Context, args: argparse.Namespace): - self.main_entry = "main_entry" + self.main_entry = "entry" self.handle_args(args) self.ir_module = self.init_module(ctx) @@ -120,7 +116,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: with ctx, ir.Location.unknown(): layers = len(self.hidden_size_list) - 1 module = ir.Module.create() - dtype = str_to_mlir_dtype(self.dtype, ctx) + dtype = str_to_mlir_dtype(ctx, self.dtype) src = ir.RankedTensorType.get( [self.batch_size, self.hidden_size_list[0]], dtype ) @@ -186,7 +182,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ) if self.act_type == "relu": - element = FloatAttr.get(dtype, 0) + element = ir.FloatAttr.get(dtype, 0) tensor_type = ir.RankedTensorType.get( layer_out_shape, dtype ) @@ -205,7 +201,7 @@ def add_args(parser: argparse.ArgumentParser): parser.add_argument("--hidden_size_list", type=str, default="") parser.add_argument("--has_bias", required=False, type=str) parser.add_argument( - "--act_type", type=str, choices=["noop", "relu", "sigmoid"], default="noop" + "--act_type", type=str, choices=["noop", "relu"], default="noop" ) parser.add_argument( "--dtype", @@ -240,7 +236,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: with ctx, ir.Location.unknown(): layers = len(self.hidden_size_list) - 1 module = ir.Module.create() - dtype = str_to_mlir_dtype(self.dtype, ctx) + dtype = str_to_mlir_dtype(ctx, self.dtype) src = ir.RankedTensorType.get( [self.batch_size, self.hidden_size_list[0]], dtype ) @@ -306,7 +302,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ) if self.act_type == "relu": - element = FloatAttr.get(dtype, 0) + element = ir.FloatAttr.get(dtype, 0) tensor_type = ir.RankedTensorType.get( layer_out_shape, dtype ) @@ -316,4 +312,4 @@ def init_module(self, ctx: ir.Context) -> ir.Module: data, cst, outs=[tensor.EmptyOp(layer_out_shape, dtype)] ) func.ReturnOp([data]) - return module \ No newline at end of file + return module diff --git a/test/benchgc/src/benchgc/runner.py b/test/benchgc/src/benchgc/runner.py index 80178baa8..1f4e18e37 100644 --- a/test/benchgc/src/benchgc/runner.py +++ b/test/benchgc/src/benchgc/runner.py @@ -19,16 +19,16 @@ import gc_mlir._mlir_libs import gc_mlir.dialects import gc_mlir.dialects.func -import gc_mlir.ir import torch from benchgc.arith import ref_op as arith_ref_op from benchgc.linalg import ref_op as linalg_ref_op from benchgc.mlir.util import MLIRCache from benchgc.tensor import ref_op as tensor_ref_op +from gc_mlir import ir def dfs_op( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: dialect_call: str = str(op.name) @@ -65,7 +65,7 @@ def dfs_op( def dfs_region( - cache: MLIRCache, region: gc_mlir.ir.Region, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, region: ir.Region, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: build_cache = len(cache.next) == 0 for i in range(len(region.blocks)): @@ -82,7 +82,7 @@ def dfs_region( def dfs_block( - cache: MLIRCache, block: gc_mlir.ir.Block, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, block: ir.Block, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: build_cache = len(cache.next) == 0 for i in range(len(block.operations)): diff --git a/test/benchgc/src/benchgc/tensor/__init__.py b/test/benchgc/src/benchgc/tensor/__init__.py index 2f8bc98a4..9bfd9b09d 100644 --- a/test/benchgc/src/benchgc/tensor/__init__.py +++ b/test/benchgc/src/benchgc/tensor/__init__.py @@ -18,21 +18,19 @@ import importlib from typing import Callable, Dict, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[ - str, Callable[[argparse.Namespace, Dict[str, Arg]], gc_mlir.ir.Module] -] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, Dict[str, Arg]], ir.Module]] = {} for dri in ["basic", "shape"]: mod = importlib.import_module(f"benchgc.tensor.{dri}") diff --git a/test/benchgc/src/benchgc/tensor/basic.py b/test/benchgc/src/benchgc/tensor/basic.py index eb56aafbc..a424a7bb2 100644 --- a/test/benchgc/src/benchgc/tensor/basic.py +++ b/test/benchgc/src/benchgc/tensor/basic.py @@ -17,13 +17,13 @@ from typing import Dict, Tuple import benchgc.util -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_empty( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.zeros( diff --git a/test/benchgc/src/benchgc/tensor/shape.py b/test/benchgc/src/benchgc/tensor/shape.py index 18d9fbb2c..25fc20e53 100644 --- a/test/benchgc/src/benchgc/tensor/shape.py +++ b/test/benchgc/src/benchgc/tensor/shape.py @@ -16,16 +16,16 @@ from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_collapse_shape( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # permute axis and do reshape - reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + reassociation: ir.ArrayAttr = op.attributes["reassociation"] permutation: List[int] = [] shape: List[int] = [] for outdim in reassociation: @@ -43,10 +43,10 @@ def ref_collapse_shape( def ref_expand_shape( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # permute axis and do reshape - reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + reassociation: ir.ArrayAttr = op.attributes["reassociation"] permutation: List[int] = [0] * len(op.result.type.shape) shape: List[int] = [] From 1e8b074787b35a26e06eb14ed4c1e58e7889847e Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 23:18:57 -0700 Subject: [PATCH 07/38] fix format --- test/benchgc/src/benchgc/__main__.py | 16 +++++++++------- test/benchgc/src/benchgc/mlir/bench.py | 8 ++++---- test/benchgc/src/benchgc/pattern/mlp.py | 1 + 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 95f2f472c..f29416b4d 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -159,7 +159,7 @@ def add_common_options(parser: argparse.ArgumentParser): def add_bench_options(parser: argparse.ArgumentParser): - ''' add options for bench mode''' + """add options for bench mode""" if parser.parse_known_args()[0].mode == "P": parser.add_argument( "-p", @@ -183,7 +183,7 @@ def get_pattern_clz(driver_str: str): def add_pattern_options(parser: argparse.ArgumentParser): - '''add options for each pattern''' + """add options for each pattern""" if parser.parse_known_args()[0].driver == "pattern": pattern_name = parser.parse_known_args()[0].case get_pattern_clz(pattern_name).add_args(parser) @@ -265,6 +265,7 @@ def get_module_and_args(flags): print(module) return module, args + def correctness_testing(flags, module, args): ref_args: List[torch.Tensor] = [] gc_args: List[torch.Tensor | int] = [] @@ -319,7 +320,7 @@ def correctness_testing(flags, module, args): def performance_testing(flags, module, args): gc_args: List[torch.Tensor | int] = [] - gc_tensors: Dict[str, torch.Tensor] = {} + gc_tensors: Dict[str, torch.Tensor] = {} for i in range(len(args)): tensor = fill_tensor(flags, args[i], i) gc_tensors["%arg" + str(i)] = tensor @@ -327,11 +328,11 @@ def performance_testing(flags, module, args): gc_args.append(tensor.data_ptr()) else: gc_args.append(tensor) - + mlir_args = get_mlir_args(gc_args) - with module.context as ctx: + with module.context as ctx: if flags.print_ir: - ctx.enable_multithreading(False) + ctx.enable_multithreading(False) bench_kind = py_timeit_bench if flags.bench_kind == "py" else mlir_wrapper_bench execute_cost, compile_cost = bench_kind( module, @@ -341,7 +342,7 @@ def performance_testing(flags, module, args): flags.print_ir, flags.repeat, flags.warm_up, - ) + ) print("===========bench result===========") json_res = json.dumps( { @@ -353,6 +354,7 @@ def performance_testing(flags, module, args): ) print(json_res) + if __name__ == "__main__": arg_parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") add_common_options(arg_parser) diff --git a/test/benchgc/src/benchgc/mlir/bench.py b/test/benchgc/src/benchgc/mlir/bench.py index b005775a4..483ff0023 100644 --- a/test/benchgc/src/benchgc/mlir/bench.py +++ b/test/benchgc/src/benchgc/mlir/bench.py @@ -18,16 +18,16 @@ import ctypes import random import timeit -from typing import List, Sequence, Tuple +from typing import List, Tuple import numpy as np -from gc_mlir import ir, runtime -from gc_mlir.graph_compiler import GraphCompiler from benchgc.mlir.util import ( emit_benchmark_wrapped_main_func, emit_nano_time, get_kernel_func_from_module, ) +from gc_mlir import ir, runtime +from gc_mlir.graph_compiler import GraphCompiler def py_timeit_bench( @@ -203,4 +203,4 @@ def run(engine_invoke, bench_func_name, *mlir_args): execute_cost = total_time / repeat_time execute_costs.append(execute_cost) - return list(zip(compile_costs, execute_costs)) \ No newline at end of file + return list(zip(compile_costs, execute_costs)) diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index 9182594d2..847135d4f 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -53,6 +53,7 @@ def to_bool_list(s: str) -> List[bool]: return [] return [bool(int(i)) for i in s.strip().split("x")] + class Pattern(ABC): """Abstract class for driver.""" From 1c201849ad99d81efa76cb77239cde3aef901ac9 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 26 Aug 2024 23:24:00 -0700 Subject: [PATCH 08/38] fix format --- test/benchgc/src/benchgc/mlir/util.py | 2 - test/benchgc/src/benchgc/pattern/__init__.py | 2 +- test/benchgc/src/benchgc/pattern/mlp.py | 121 +------------------ 3 files changed, 2 insertions(+), 123 deletions(-) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index ff8146f20..d6ef7ea75 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -109,8 +109,6 @@ def str_to_mlir_typed_attr(ctx: ir.Context, dtype: str, value: Any) -> ir.Attrib raise Exception(f"data type not support: {dtype}") - - def emit_nano_time() -> func.FuncOp: """Emit a nanoTime function that returns the current time in nanoseconds.""" nanoTime = func.FuncOp( diff --git a/test/benchgc/src/benchgc/pattern/__init__.py b/test/benchgc/src/benchgc/pattern/__init__.py index e97242f09..ba0b8696a 100644 --- a/test/benchgc/src/benchgc/pattern/__init__.py +++ b/test/benchgc/src/benchgc/pattern/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. ################################################################################ -import sys import pathlib +import sys sys.path.append(pathlib.Path(__file__).parent.resolve().__str__()) \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index 847135d4f..c75c96b09 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -75,6 +75,7 @@ def __init__(self, ctx: ir.Context, args: argparse.Namespace): def init_module(self, ctx: ir.Context) -> ir.Module: """Create MLIR moudule by args""" + class MLP(Pattern): @staticmethod def add_args(parser: argparse.ArgumentParser): @@ -194,123 +195,3 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ) func.ReturnOp([data]) return module - - - @staticmethod - def add_args(parser: argparse.ArgumentParser): - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--hidden_size_list", type=str, default="") - parser.add_argument("--has_bias", required=False, type=str) - parser.add_argument( - "--act_type", type=str, choices=["noop", "relu"], default="noop" - ) - parser.add_argument( - "--dtype", - type=str, - choices=[ - "f32", - "bf16", - ], - default="f32", - ) - - def handle_args(self, args: argparse.Namespace): - self.batch_size = args.batch_size - assert self.batch_size > 0, "batch size should be greater than 0" - - self.hidden_size_list = to_int_list(args.hidden_size_list) - layers = len(self.hidden_size_list) - 1 - assert layers >= 1, "hidden_size_list should have at least 2 elements" - - self.has_bias = ( - [False] * layers if args.has_bias is None else to_bool_list(args.has_bias) - ) - - assert ( - len(self.has_bias) == layers - ), "has_bias should have the same length as hidden_size_list" - - self.act_type = args.act_type - self.dtype = args.dtype - - def init_module(self, ctx: ir.Context) -> ir.Module: - with ctx, ir.Location.unknown(): - layers = len(self.hidden_size_list) - 1 - module = ir.Module.create() - dtype = str_to_mlir_dtype(ctx, self.dtype) - src = ir.RankedTensorType.get( - [self.batch_size, self.hidden_size_list[0]], dtype - ) - weights = [] - bias = [] - for i in range(layers): - weights.append( - ir.RankedTensorType.get( - [ - self.hidden_size_list[i], - self.hidden_size_list[i + 1], - ], - dtype, - ) - ) - if self.has_bias[i]: - bias.append( - ir.RankedTensorType.get([self.hidden_size_list[i + 1]], dtype) - ) - result = ir.RankedTensorType.get( - [ - self.batch_size, - self.hidden_size_list[-1], - ], - dtype, - ) - with ir.InsertionPoint(module.body): - f = func.FuncOp( - name=self.main_entry, - type=ir.FunctionType.get( - inputs=[src] + weights + bias, results=[result] - ), - ) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - data = f.entry_block.arguments[0] - bias_idx = len(weights) + 1 - for i in range(layers): - weight = f.entry_block.arguments[i + 1] - if self.has_bias[i]: - bias = f.entry_block.arguments[bias_idx] - bias_idx += 1 - else: - bias = None - layer_out_shape = [ - self.batch_size, - self.hidden_size_list[i + 1], - ] - - data = linalg.matmul( - data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] - ) - if bias: - broadcast_bias = linalg.broadcast( - bias, - outs=[tensor.EmptyOp(layer_out_shape, dtype)], - dimensions=[0], - ) - data = linalg.add( - data, - broadcast_bias, - outs=[tensor.EmptyOp(layer_out_shape, dtype)], - ) - - if self.act_type == "relu": - element = ir.FloatAttr.get(dtype, 0) - tensor_type = ir.RankedTensorType.get( - layer_out_shape, dtype - ) - attr = ir.DenseElementsAttr.get_splat(tensor_type, element) - cst = arith.ConstantOp(tensor_type, attr) - data = linalg.max( - data, cst, outs=[tensor.EmptyOp(layer_out_shape, dtype)] - ) - func.ReturnOp([data]) - return module From 69f2e9460077ba3b47fa3032c1b9345d247b9b7d Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 27 Aug 2024 00:06:55 -0700 Subject: [PATCH 09/38] reorg the pattern dir --- test/benchgc/src/benchgc/__main__.py | 7 +-- test/benchgc/src/benchgc/pattern/__init__.py | 11 ++-- test/benchgc/src/benchgc/pattern/base.py | 43 +++++++++++++++ test/benchgc/src/benchgc/pattern/mlp.py | 55 +------------------- test/benchgc/src/benchgc/pattern/util.py | 48 +++++++++++++++++ 5 files changed, 102 insertions(+), 62 deletions(-) create mode 100644 test/benchgc/src/benchgc/pattern/base.py create mode 100644 test/benchgc/src/benchgc/pattern/util.py diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index f29416b4d..736b9f2a1 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -33,7 +33,7 @@ from benchgc.arg.arg import Arg from benchgc.mlir.arg import get_mlir_args from benchgc.mlir.bench import mlir_wrapper_bench, py_timeit_bench -from benchgc.pattern.mlp import MLP +from benchgc.pattern import get_pattern_clz from gc_mlir import ir from gc_mlir.graph_compiler import GraphCompiler @@ -176,11 +176,6 @@ def add_bench_options(parser: argparse.ArgumentParser): parser.add_argument("--entry", type=str, default="main_entry") -def get_pattern_clz(driver_str: str): - """Function getting Pattern class by name.""" - clz = {"mlp": MLP}[driver_str] - return clz - def add_pattern_options(parser: argparse.ArgumentParser): """add options for each pattern""" diff --git a/test/benchgc/src/benchgc/pattern/__init__.py b/test/benchgc/src/benchgc/pattern/__init__.py index ba0b8696a..c9c035202 100644 --- a/test/benchgc/src/benchgc/pattern/__init__.py +++ b/test/benchgc/src/benchgc/pattern/__init__.py @@ -14,7 +14,12 @@ # limitations under the License. ################################################################################ -import pathlib -import sys +from .base import Pattern +from .mlp import MLP -sys.path.append(pathlib.Path(__file__).parent.resolve().__str__()) \ No newline at end of file +__all__ = ["Pattern", "MLP", "get_pattern_clz"] + +def get_pattern_clz(name: str): + """Function getting pattern class by name.""" + clz = {"mlp": MLP}[name] + return clz \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/base.py b/test/benchgc/src/benchgc/pattern/base.py new file mode 100644 index 000000000..e0f7d7618 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/base.py @@ -0,0 +1,43 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import argparse +from abc import ABC, abstractmethod + +from gc_mlir import ir + + +class Pattern(ABC): + """Abstract class for pattern.""" + + @staticmethod + @abstractmethod + def add_args(parser: argparse.ArgumentParser): + """Add arguments to parser""" + + @abstractmethod + def handle_args(self, args: argparse.Namespace): + """Get and handle the args""" + + def __init__(self, ctx: ir.Context, args: argparse.Namespace): + self.main_entry = "entry" + self.handle_args(args) + self.ir_module = self.init_module(ctx) + + @abstractmethod + def init_module(self, ctx: ir.Context) -> ir.Module: + """Create MLIR moudule by args""" diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index c75c96b09..2f4bea774 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -16,64 +16,13 @@ ################################################################################ import argparse -from abc import ABC, abstractmethod -from typing import List from benchgc.mlir.util import str_to_mlir_dtype from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor - -def to_int_list(s: str) -> List[int]: - """ - Parsing the cmd for list of int values - - Args: - s (str): int values in cmd, example: 2x3x4 - - Returns: - List[int]: int values in list, example: [2, 3, 4] - """ - if not s or len(s) == 0: - return [] - return [int(i) for i in s.strip().split("x")] - - -def to_bool_list(s: str) -> List[bool]: - """ - Parsing the cmd for list of bool values - - Args: - s (str): bools in cmd, example: 1x0x1 - - Returns: - List[bool]: bools in list, example: [True, False, True] - """ - if not s or len(s) == 0: - return [] - return [bool(int(i)) for i in s.strip().split("x")] - - -class Pattern(ABC): - """Abstract class for driver.""" - - @staticmethod - @abstractmethod - def add_args(parser: argparse.ArgumentParser): - """Add arguments to parser""" - - @abstractmethod - def handle_args(self, args: argparse.Namespace): - """Get and handle the args""" - - def __init__(self, ctx: ir.Context, args: argparse.Namespace): - self.main_entry = "entry" - self.handle_args(args) - self.ir_module = self.init_module(ctx) - - @abstractmethod - def init_module(self, ctx: ir.Context) -> ir.Module: - """Create MLIR moudule by args""" +from .base import Pattern +from .util import to_bool_list, to_int_list class MLP(Pattern): diff --git a/test/benchgc/src/benchgc/pattern/util.py b/test/benchgc/src/benchgc/pattern/util.py new file mode 100644 index 000000000..62bf74ca5 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/util.py @@ -0,0 +1,48 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +from typing import List + + +def to_int_list(s: str) -> List[int]: + """ + Parsing the cmd for list of int values + + Args: + s (str): int values in cmd, example: 2x3x4 + + Returns: + List[int]: int values in list, example: [2, 3, 4] + """ + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_list(s: str) -> List[bool]: + """ + Parsing the cmd for list of bool values + + Args: + s (str): bools in cmd, example: 1x0x1 + + Returns: + List[bool]: bools in list, example: [True, False, True] + """ + if not s or len(s) == 0: + return [] + return [bool(int(i)) for i in s.strip().split("x")] From 8d0953c696217100a1e9a89314c1cc7860a73357 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 27 Aug 2024 01:48:39 -0700 Subject: [PATCH 10/38] improve --- test/benchgc/src/benchgc/__main__.py | 118 +++++++++++---------- test/benchgc/src/benchgc/linalg/binary.py | 7 ++ test/benchgc/src/benchgc/linalg/conv.py | 22 ++++ test/benchgc/src/benchgc/linalg/eltwise.py | 12 +++ test/benchgc/src/benchgc/linalg/matmul.py | 14 +++ test/benchgc/src/benchgc/linalg/misc.py | 3 + test/benchgc/src/benchgc/linalg/pool.py | 12 +++ test/benchgc/src/benchgc/linalg/softmax.py | 1 + test/benchgc/src/benchgc/mlir/module.py | 4 +- test/benchgc/src/benchgc/mlir/util.py | 12 +-- test/benchgc/src/benchgc/pattern/base.py | 6 +- test/benchgc/src/benchgc/util.py | 1 + 12 files changed, 141 insertions(+), 71 deletions(-) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 736b9f2a1..6f0edcfa5 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -107,67 +107,71 @@ def add_common_options(parser: argparse.ArgumentParser): benchgc.util.ERROR_OUTPUT_VERBOSE, benchgc.util.OUTPUT_VERBOSE, benchgc.util.INPUT_VERBOSE, + benchgc.util.PIPELINE_VERBOSE, ], ) + parser.add_argument( - "--cast", + "--entry", required=False, - default="cast_signed", - help="define attribute supported by linalg op such as matmul_transpose_b", - choices=["cast_signed", "cast_unsigned"], + default="entry", + help="the entry func name of a mlir", type=str, ) - # single dimension index - # linalg.softmax - parser.add_argument( - "--dimension", - required=False, - default=None, - help="define the dimension attribute in linalg op", - type=int, - ) + if parser.parse_known_args()[0].driver == "linalg": + parser.add_argument( + "--cast", + required=False, + default="cast_signed", + help="define attribute supported by linalg op such as matmul_transpose_b", + choices=["cast_signed", "cast_unsigned"], + type=str, + ) - # multiple dimensions array - # linalg.broadcast / linalg.reduce - parser.add_argument( - "--dimensions", - required=False, - default=None, - action="append", - help="define the dimensions attribute in linalg op", - type=int, - ) + # single dimension index + # linalg.softmax + parser.add_argument( + "--dimension", + required=False, + default=None, + help="define the dimension attribute in linalg op", + type=int, + ) - parser.add_argument( - "--dilations", - required=False, - default=None, - action="append", - help="define the dilations attribute in linalg op", - type=int, - ) + # multiple dimensions array + # linalg.broadcast / linalg.reduce + parser.add_argument( + "--dimensions", + required=False, + default=None, + action="append", + help="define the dimensions attribute in linalg op", + type=int, + ) - parser.add_argument( - "--strides", - required=False, - default=None, - action="append", - help="define the strides attribute in linalg op", - type=int, - ) + parser.add_argument( + "--dilations", + required=False, + default=None, + action="append", + help="define the dilations attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--strides", + required=False, + default=None, + action="append", + help="define the strides attribute in linalg op", + type=int, + ) def add_bench_options(parser: argparse.ArgumentParser): """add options for bench mode""" if parser.parse_known_args()[0].mode == "P": - parser.add_argument( - "-p", - "--print_ir", - action="store_true", - help="if need print the IR after pipeline", - required=False, - ) parser.add_argument( "--bench_kind", type=str, choices=["py", "wrapper"], default="py" ) @@ -196,7 +200,7 @@ def get_module_and_args(flags): pattern_clz = get_pattern_clz(flags.case) module = pattern_clz(ctx, flags).ir_module - entry = benchgc.mlir.util.get_entry(module) + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) idx: int = 0 # FIXME: only support RankTensorType now for i in entry.type.inputs: @@ -246,7 +250,7 @@ def get_module_and_args(flags): idx = int(cmp[:colon]) args[idx].set_cmp(cmp[colon + 1 :]) - entry = benchgc.mlir.util.get_entry(module) + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) for i, arg in enumerate(args): # use zero filling if the arg is return value @@ -277,7 +281,7 @@ def correctness_testing(flags, module, args): else: gc_args.append(tensor) - entry = benchgc.mlir.util.get_entry(module) + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) # ref_out contains return value of the entry ref_out = runner.ref_run(entry, ref_tensors) @@ -289,10 +293,13 @@ def correctness_testing(flags, module, args): mlir_args = get_mlir_args(gc_args) passes = "any(gc-cpu-pipeline)" - with module.context: + with module.context as ctx: + ir_printing = flags.verbose >= benchgc.util.PIPELINE_VERBOSE + if ir_printing: + ctx.enable_multithreading(False) compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module) - engine.invoke("entry", *mlir_args) + engine = compiler.compile_and_jit(module, ir_printing) + engine.invoke(flags.entry, *mlir_args) fail, mistrust = False, False for i in range(len(args)): @@ -326,15 +333,16 @@ def performance_testing(flags, module, args): mlir_args = get_mlir_args(gc_args) with module.context as ctx: - if flags.print_ir: + ir_printing = flags.verbose >= benchgc.util.PIPELINE_VERBOSE + if ir_printing: ctx.enable_multithreading(False) bench_kind = py_timeit_bench if flags.bench_kind == "py" else mlir_wrapper_bench execute_cost, compile_cost = bench_kind( module, - "entry", + flags.entry, "any(gc-cpu-pipeline)", mlir_args, - flags.print_ir, + ir_printing, flags.repeat, flags.warm_up, ) diff --git a/test/benchgc/src/benchgc/linalg/binary.py b/test/benchgc/src/benchgc/linalg/binary.py index b8508d8a0..66f3a7abe 100644 --- a/test/benchgc/src/benchgc/linalg/binary.py +++ b/test/benchgc/src/benchgc/linalg/binary.py @@ -33,6 +33,7 @@ def ref_add( def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -49,6 +50,7 @@ def ref_powf( def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -65,6 +67,7 @@ def ref_div( def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -81,6 +84,7 @@ def ref_max( def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -97,6 +101,7 @@ def ref_min( def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -113,6 +118,7 @@ def ref_mul( def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -129,6 +135,7 @@ def ref_sub( def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/conv.py b/test/benchgc/src/benchgc/linalg/conv.py index 359d0d586..f1f8d7f97 100644 --- a/test/benchgc/src/benchgc/linalg/conv.py +++ b/test/benchgc/src/benchgc/linalg/conv.py @@ -42,6 +42,7 @@ def ref_conv_1d_ncw_fcw( def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -80,6 +81,7 @@ def ref_conv_1d_nwc_wcf( def mlir_conv_1d_nwc_wcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -111,6 +113,7 @@ def ref_conv_1d_ncw_fcw( def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -140,6 +143,7 @@ def ref_conv_1d( def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -169,6 +173,7 @@ def ref_conv_2d_nchw_fchw( def mlir_conv_2d_nchw_fchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -215,6 +220,7 @@ def ref_conv_2d_ngchw_fgchw( def mlir_conv_2d_ngchw_fgchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -259,6 +265,7 @@ def ref_conv_2d_ngchw_gfchw( def mlir_conv_2d_ngchw_gfchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -292,6 +299,7 @@ def ref_conv_2d_nhwc_fhwc( def mlir_conv_2d_nhwc_fhwc(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -325,6 +333,7 @@ def ref_conv_2d_nhwc_hwcf( def mlir_conv_2d_nhwc_hwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -354,6 +363,7 @@ def ref_conv_2d( def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -383,6 +393,7 @@ def ref_conv_3d_ncdhw_fcdhw( def mlir_conv_3d_ncdhw_fcdhw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -416,6 +427,7 @@ def ref_conv_3d_ndhwc_dhwcf( def mlir_conv_3d_ndhwc_dhwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -445,6 +457,7 @@ def ref_conv_3d( def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -478,6 +491,7 @@ def mlir_depthwise_conv_1d_ncw_cw( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -515,6 +529,7 @@ def mlir_depthwise_conv_1d_nwc_wc( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -557,6 +572,7 @@ def mlir_depthwise_conv_1d_nwc_wcm( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -592,6 +608,7 @@ def mlir_depthwise_conv_2d_nchw_chw( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -629,6 +646,7 @@ def mlir_depthwise_conv_2d_nhwc_hwc( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -674,6 +692,7 @@ def mlir_depthwise_conv_2d_nhwc_hwcm( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -709,6 +728,7 @@ def mlir_depthwise_conv_3d_ncdhw_cdhw( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -746,6 +766,7 @@ def mlir_depthwise_conv_3d_ndhwc_dhwc( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -800,6 +821,7 @@ def mlir_depthwise_conv_3d_ndhwc_dhwcm( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/eltwise.py b/test/benchgc/src/benchgc/linalg/eltwise.py index d6e5a9c5d..760fcb0a1 100644 --- a/test/benchgc/src/benchgc/linalg/eltwise.py +++ b/test/benchgc/src/benchgc/linalg/eltwise.py @@ -33,6 +33,7 @@ def ref_abs( def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.abs(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -47,6 +48,7 @@ def ref_ceil( def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.ceil(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -61,6 +63,7 @@ def ref_floor( def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.floor(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -75,6 +78,7 @@ def ref_erf( def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.erf(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -83,6 +87,7 @@ def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.log(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -97,6 +102,7 @@ def ref_log( def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -117,6 +123,7 @@ def ref_exp( def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -135,6 +142,7 @@ def ref_round( def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.round(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -149,6 +157,7 @@ def ref_rsqrt( def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.rsqrt(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -163,6 +172,7 @@ def ref_sqrt( def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.sqrt(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -177,6 +187,7 @@ def ref_square( def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.square(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -191,6 +202,7 @@ def ref_tanh( def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.tanh(arg0, outs=[args[1].get_zero_op(ctx)])], diff --git a/test/benchgc/src/benchgc/linalg/matmul.py b/test/benchgc/src/benchgc/linalg/matmul.py index 1c36d9dbd..16ad0519c 100644 --- a/test/benchgc/src/benchgc/linalg/matmul.py +++ b/test/benchgc/src/benchgc/linalg/matmul.py @@ -34,6 +34,7 @@ def ref_batch_matmul( def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -52,6 +53,7 @@ def mlir_batch_matmul_transpose_a( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -70,6 +72,7 @@ def mlir_batch_matmul_transpose_b( flags: argparse.Namespace, args: List[Arg] ) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -89,6 +92,7 @@ def ref_batch_matvec( def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -126,6 +130,7 @@ def ref_batch_mmt4d( def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -150,6 +155,7 @@ def ref_batch_reduce_matmul( def mlir_batch_reduce_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -168,6 +174,7 @@ def ref_batch_vecmat( def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -184,6 +191,7 @@ def ref_dot( def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -200,6 +208,7 @@ def ref_matmul( def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -218,6 +227,7 @@ def ref_matmul_transpose_a( def mlir_matmul_transpose_a(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -236,6 +246,7 @@ def ref_matmul_transpose_b( def mlir_matmul_transpose_b(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -254,6 +265,7 @@ def ref_matvec( def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -285,6 +297,7 @@ def ref_mmt4d( def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -303,6 +316,7 @@ def ref_vecmat( def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/misc.py b/test/benchgc/src/benchgc/linalg/misc.py index 6020e54db..05f8ebbbe 100644 --- a/test/benchgc/src/benchgc/linalg/misc.py +++ b/test/benchgc/src/benchgc/linalg/misc.py @@ -45,6 +45,7 @@ def ref_broadcast( def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ @@ -63,6 +64,7 @@ def ref_fill( def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ @@ -86,6 +88,7 @@ def ref_copy( def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ diff --git a/test/benchgc/src/benchgc/linalg/pool.py b/test/benchgc/src/benchgc/linalg/pool.py index 1b2f03412..755e4c76a 100644 --- a/test/benchgc/src/benchgc/linalg/pool.py +++ b/test/benchgc/src/benchgc/linalg/pool.py @@ -42,6 +42,7 @@ def ref_pooling_nchw_max( def mlir_pooling_nchw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -83,6 +84,7 @@ def ref_pooling_nchw_sum( def mlir_pooling_nchw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -114,6 +116,7 @@ def ref_pooling_ncw_max( def mlir_pooling_ncw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -155,6 +158,7 @@ def ref_pooling_ncw_sum( def mlir_pooling_ncw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -188,6 +192,7 @@ def ref_pooling_ndhwc_max( def mlir_pooling_ndhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -231,6 +236,7 @@ def ref_pooling_ndhwc_sum( def mlir_pooling_ndhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -264,6 +270,7 @@ def ref_pooling_nhwc_max( def mlir_pooling_nhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -307,6 +314,7 @@ def ref_pooling_nhwc_sum( def mlir_pooling_nhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -341,6 +349,7 @@ def ref_pooling_nhwc_min( def mlir_pooling_nhwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -374,6 +383,7 @@ def ref_pooling_nwc_max( def mlir_pooling_nwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -408,6 +418,7 @@ def ref_pooling_nwc_min( def mlir_pooling_nwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -451,6 +462,7 @@ def ref_pooling_nwc_sum( def mlir_pooling_nwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/softmax.py b/test/benchgc/src/benchgc/linalg/softmax.py index fe18a0017..e56376404 100644 --- a/test/benchgc/src/benchgc/linalg/softmax.py +++ b/test/benchgc/src/benchgc/linalg/softmax.py @@ -34,6 +34,7 @@ def ref_softmax( def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py index 4dccb2928..4633c1f25 100644 --- a/test/benchgc/src/benchgc/mlir/module.py +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -16,13 +16,13 @@ from typing import Callable, List, Tuple -import gc_mlir.dialects.tensor from benchgc.mlir.arg import MLIRArg from gc_mlir import ir from gc_mlir.dialects import func def init_module( + entry_name: str, inputs: Tuple[MLIRArg, ...], outputs: Tuple[MLIRArg, ...], op_func: Callable[ @@ -34,7 +34,7 @@ def init_module( module = ir.Module.create() with ir.InsertionPoint(module.body): f = func.FuncOp( - name="entry", + name=get_entry_name(), type=ir.FunctionType.get( inputs=[x.get_mlir_type(ctx) for x in inputs], results=[x.get_mlir_type(ctx) for x in outputs], diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index d6ef7ea75..faaf4daef 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -21,16 +21,6 @@ from gc_mlir import ir from gc_mlir.dialects import arith, func, memref -# only python 3.11 support -# from typing import Self - - -def get_entry(module: ir.Module, entry: str = '"entry"') -> func.FuncOp: - for op in module.operation.opview.regions[0].blocks[0].operations: - if str(op.name) == entry: - return op - raise Exception(f"entry function {entry} is not found at the top level") - # calling python binding consumes a lot of time e.g. get_name() # we need to cache some result to avoid duplicate call @@ -149,7 +139,7 @@ def emit_benchmark_wrapped_main_func( def get_kernel_func_from_module( - module: ir.Module, func_name: str = "main_entry" + module: ir.Module, func_name: str = "entry" ) -> func.FuncOp: """Get the func op by the name from a module""" assert ( diff --git a/test/benchgc/src/benchgc/pattern/base.py b/test/benchgc/src/benchgc/pattern/base.py index e0f7d7618..42527efc3 100644 --- a/test/benchgc/src/benchgc/pattern/base.py +++ b/test/benchgc/src/benchgc/pattern/base.py @@ -33,9 +33,9 @@ def add_args(parser: argparse.ArgumentParser): def handle_args(self, args: argparse.Namespace): """Get and handle the args""" - def __init__(self, ctx: ir.Context, args: argparse.Namespace): - self.main_entry = "entry" - self.handle_args(args) + def __init__(self, ctx: ir.Context, flags: argparse.Namespace): + self.main_entry = flags.entry + self.handle_args(flags) self.ir_module = self.init_module(ctx) @abstractmethod diff --git a/test/benchgc/src/benchgc/util.py b/test/benchgc/src/benchgc/util.py index de275f0fd..a2fea05d6 100644 --- a/test/benchgc/src/benchgc/util.py +++ b/test/benchgc/src/benchgc/util.py @@ -30,6 +30,7 @@ ERROR_OUTPUT_VERBOSE = 4 # + print all error data points if failed OUTPUT_VERBOSE = 5 # + print all result including passed tensor INPUT_VERBOSE = 6 # + print input torch tensors +PIPELINE_VERBOSE = 7 # + print ir when running pipeline """ acc | acc | elems | value_range | worst case From e05d5f0d8778867fec66029f4a9b20f173867abe Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 27 Aug 2024 01:52:38 -0700 Subject: [PATCH 11/38] fix format --- test/benchgc/src/benchgc/pattern/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/benchgc/src/benchgc/pattern/__init__.py b/test/benchgc/src/benchgc/pattern/__init__.py index c9c035202..b0f5cbce0 100644 --- a/test/benchgc/src/benchgc/pattern/__init__.py +++ b/test/benchgc/src/benchgc/pattern/__init__.py @@ -19,7 +19,8 @@ __all__ = ["Pattern", "MLP", "get_pattern_clz"] + def get_pattern_clz(name: str): """Function getting pattern class by name.""" clz = {"mlp": MLP}[name] - return clz \ No newline at end of file + return clz From e96d310a6850de9bbe229cc4f0f1b13bc4e51392 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 27 Aug 2024 02:15:36 -0700 Subject: [PATCH 12/38] fix --- test/benchgc/src/benchgc/mlir/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py index 4633c1f25..69dfd9a90 100644 --- a/test/benchgc/src/benchgc/mlir/module.py +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -34,7 +34,7 @@ def init_module( module = ir.Module.create() with ir.InsertionPoint(module.body): f = func.FuncOp( - name=get_entry_name(), + name=entry_name, type=ir.FunctionType.get( inputs=[x.get_mlir_type(ctx) for x in inputs], results=[x.get_mlir_type(ctx) for x in outputs], From 44d591df137eebf3a860d7d5f76e8a0eaa0bd7fb Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 27 Aug 2024 06:01:36 -0700 Subject: [PATCH 13/38] add example --- test/benchgc/README.md | 71 ++++++++++++++++++++++++---- test/benchgc/src/benchgc/__main__.py | 1 - 2 files changed, 61 insertions(+), 11 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index ab63cb156..ca719e75e 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -27,7 +27,7 @@ python -m pip install test/benchgc/dist/benchgc-*.whl ``` python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] ``` -## Flags +## Common Options ### --mode [str] * C : correctness testing (by default) * P : performance testing @@ -42,6 +42,9 @@ python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] * if driver=pattern, please provide the pre-defined pattern name, such as mlp here * if driver is a dialect name, please provide the detail op name to start a single op test +### --entry [str] +* the entry name of the kernel of input mlir or generated mlir + ### --seed [int] * set the seed to generate the test data and reprodce the test @@ -101,15 +104,10 @@ module { | Norm check | N | threshold | | Benchdnn driver | D | driver_name:dtype:case | -### --pattern - - - - -## Perfermance testing flags +## Bench Options ### --bench_kind [str] -* py -* wrapper +* py : use the MLIR Python API to invoke the kernel and use Python to calculate the time cost +* wrapper : modify MLIR by wrapping the kernel into a new method and calling the `nanoTime()` method before and after calling the kernel. Finally, calculate the difference as the time cost ### --warm_up [int] * warm-up times of the execution @@ -117,7 +115,16 @@ module { ### --repeat * repeat times of the execution -### Example +## pattern options +Each pattern has its own unique options. +### mlp +* `--batch_size`: the input +* `--hidden_size_list`: hidden_sizes of mlp, example: 32x16x64 +* `--has_bias`: if the matmul op has bias, example: 1x0 +* `--act_type`: choices=["noop", "relu", "sigmoid"] +* `--dtype`: choices=["bf16", "f32"] + +## Example ### Correctness testing example ``` # single add op test @@ -275,4 +282,48 @@ p2p check: threshold: 0.0000000 (1, 0): ref: 25.1690636 res: 25.1690636 abs_diff: 0.0000000 rel_diff: 0.0000000 (1, 1): ref: -7.8600063 res: -7.8600044 abs_diff: 0.0000019 rel_diff: 0.0000002 FAIL: linalg.matmul_transpose_b +``` + +### Perf testing example +``` +python3 -m benchgc --verbose 1 --mode P --driver linalg --case add --md 0:4x5xf32 --md 1:4x5xf32 --md 2:4x5xf32 + +module { + func.func @entry(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<4x5xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x5xf32>) -> tensor<4x5xf32> + %2 = linalg.add ins(%arg0, %arg1 : tensor<4x5xf32>, tensor<4x5xf32>) outs(%1 : tensor<4x5xf32>) -> tensor<4x5xf32> + return %2 : tensor<4x5xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "linalg", + "case": "add", + "md": [ + "0:4x5xf32", + "1:4x5xf32", + "2:4x5xf32" + ], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "cast": "cast_signed", + "dimension": null, + "dimensions": null, + "dilations": null, + "strides": null, + "bench_kind": "py", + "warm_up": 100, + "repeat": 100 + }, + "compile_cost(ms)": 33.73148664832115, + "execute_cost(ms)": 0.1422157883644104 +} ``` \ No newline at end of file diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 6f0edcfa5..be7418785 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -177,7 +177,6 @@ def add_bench_options(parser: argparse.ArgumentParser): ) parser.add_argument("--warm_up", type=int, default=100) parser.add_argument("--repeat", type=int, default=100) - parser.add_argument("--entry", type=str, default="main_entry") From 79231847cea3d653a74d5ee84d3f20968d51b233 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Wed, 28 Aug 2024 18:39:56 -0700 Subject: [PATCH 14/38] fix some comments --- test/benchgc/README.md | 97 +++++++++++++++++++++++-- test/benchgc/src/benchgc/__main__.py | 20 +++-- test/benchgc/src/benchgc/mlir/arg.py | 4 +- test/benchgc/src/benchgc/pattern/mlp.py | 2 +- test/benchgc/src/benchgc/util.py | 31 +++++++- 5 files changed, 136 insertions(+), 18 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index ca719e75e..0c16b5a8b 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -2,12 +2,12 @@ ## Description -Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files based on the OneDNN graph dialect as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. +Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. ## Prerequisite * python >= 3.10 * torch >= 2.2 -* pybind11 +* Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail ## Build and install ``` @@ -43,13 +43,21 @@ python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] * if driver is a dialect name, please provide the detail op name to start a single op test ### --entry [str] +* default : "entry" * the entry name of the kernel of input mlir or generated mlir ### --seed [int] * set the seed to generate the test data and reprodce the test ### --verbose [int] -* set the verbose level +* set the verbose level, default : 0 +* 0 : NO_VERBOSE +* 1 : MODULE_VERBOSE, print the module will be executed +* 2 : ARG_VERBOSE, + print arg information +* 3 : COMPARE_VERBOSE, + print threshold for comparison +* 4 : ERROR_OUTPUT_VERBOSE, + print all error data points if failed +* 5 : OUTPUT_VERBOSE, + print all result including passed tensor +* 6 : INPUT_VERBOSE, + print input torch tensors ### --md index:SHAPExTYPE * Describe the shape and data type for argument @@ -112,16 +120,16 @@ module { ### --warm_up [int] * warm-up times of the execution -### --repeat +### --repeat [int] * repeat times of the execution -## pattern options +## Pattern Options Each pattern has its own unique options. ### mlp * `--batch_size`: the input * `--hidden_size_list`: hidden_sizes of mlp, example: 32x16x64 * `--has_bias`: if the matmul op has bias, example: 1x0 -* `--act_type`: choices=["noop", "relu", "sigmoid"] +* `--act_type`: choices=["noop", "relu"] * `--dtype`: choices=["bf16", "f32"] ## Example @@ -285,6 +293,7 @@ FAIL: linalg.matmul_transpose_b ``` ### Perf testing example +* single op example ``` python3 -m benchgc --verbose 1 --mode P --driver linalg --case add --md 0:4x5xf32 --md 1:4x5xf32 --md 2:4x5xf32 @@ -326,4 +335,80 @@ module { "compile_cost(ms)": 33.73148664832115, "execute_cost(ms)": 0.1422157883644104 } +``` + +* mlir example +``` +python3 -m benchgc --mode P --verbose 1 --driver mlir --case=./test.mlir --bench_kind wrapper --warm_up 50 --repeat 200 +module { + func.func @entry(%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<5x6xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32> + %2 = linalg.abs ins(%arg0 : tensor<5x6xf32>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32> + return %2 : tensor<5x6xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "mlir", + "case": "/home/xurui/gc_v2/test.mlir", + "md": [], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "ir_printing": false, + "bench_kind": "wrapper", + "warm_up": 50, + "repeat": 200 + }, + "compile_cost(ms)": 38.10911998152733, + "execute_cost(ms)": 0.077024335 +} +``` +* mlp example +``` +python3 -m benchgc --verbose 1 --mode P --driver pattern --case mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=0x0 --act_type=noop --dtype=f32 + +module { + func.func @entry(%arg0: tensor<32x32xf32>, %arg1: tensor<32x16xf32>, %arg2: tensor<16x64xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<32x16xf32> + %1 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x16xf32>) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> + %2 = tensor.empty() : tensor<32x64xf32> + %3 = linalg.matmul {cast = #linalg.type_fn} ins(%1, %arg2 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%2 : tensor<32x64xf32>) -> tensor<32x64xf32> + return %3 : tensor<32x64xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "pattern", + "case": "mlp", + "md": [], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "ir_printing": false, + "bench_kind": "py", + "warm_up": 100, + "repeat": 100, + "batch_size": 32, + "hidden_size_list": "32x16x64", + "has_bias": "0x0", + "act_type": "noop", + "dtype": "f32" + }, + "compile_cost(ms)": 69.51220706105232, + "execute_cost(ms)": 0.43220914900302887 +} + ``` \ No newline at end of file diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index be7418785..3f604f0bf 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -107,7 +107,6 @@ def add_common_options(parser: argparse.ArgumentParser): benchgc.util.ERROR_OUTPUT_VERBOSE, benchgc.util.OUTPUT_VERBOSE, benchgc.util.INPUT_VERBOSE, - benchgc.util.PIPELINE_VERBOSE, ], ) @@ -119,6 +118,13 @@ def add_common_options(parser: argparse.ArgumentParser): type=str, ) + parser.add_argument( + "--ir_printing", + default=False, + help="if we need print the ir during the pass-pipeline", + type=bool, + ) + if parser.parse_known_args()[0].driver == "linalg": parser.add_argument( "--cast", @@ -293,11 +299,10 @@ def correctness_testing(flags, module, args): passes = "any(gc-cpu-pipeline)" with module.context as ctx: - ir_printing = flags.verbose >= benchgc.util.PIPELINE_VERBOSE - if ir_printing: + if flags.ir_printing: ctx.enable_multithreading(False) compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module, ir_printing) + engine = compiler.compile_and_jit(module, flags.ir_printing) engine.invoke(flags.entry, *mlir_args) fail, mistrust = False, False @@ -331,9 +336,8 @@ def performance_testing(flags, module, args): gc_args.append(tensor) mlir_args = get_mlir_args(gc_args) - with module.context as ctx: - ir_printing = flags.verbose >= benchgc.util.PIPELINE_VERBOSE - if ir_printing: + with module.context as ctx, ir.Location.unknown(): + if flags.ir_printing: ctx.enable_multithreading(False) bench_kind = py_timeit_bench if flags.bench_kind == "py" else mlir_wrapper_bench execute_cost, compile_cost = bench_kind( @@ -341,7 +345,7 @@ def performance_testing(flags, module, args): flags.entry, "any(gc-cpu-pipeline)", mlir_args, - ir_printing, + flags.ir_printing, flags.repeat, flags.warm_up, ) diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py index e48a1aed9..5db494c52 100644 --- a/test/benchgc/src/benchgc/mlir/arg.py +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -23,6 +23,7 @@ import gc_mlir.dialects.tensor import torch from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr +from benchgc.util import to_int_list from gc_mlir import ir @@ -95,8 +96,7 @@ def set_md(self, md: str): self.dtype = splited[-1] self.shape = [] - for dim in splited[:-1]: - self.shape.append(int(dim)) + self.shape = to_int_list(splited[:-1]) self.set_scalar() def set_scalar(self): diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index 2f4bea774..e9bf50d97 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -18,11 +18,11 @@ import argparse from benchgc.mlir.util import str_to_mlir_dtype +from benchgc.util import to_bool_list, to_int_list from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor from .base import Pattern -from .util import to_bool_list, to_int_list class MLP(Pattern): diff --git a/test/benchgc/src/benchgc/util.py b/test/benchgc/src/benchgc/util.py index a2fea05d6..ae0b4228e 100644 --- a/test/benchgc/src/benchgc/util.py +++ b/test/benchgc/src/benchgc/util.py @@ -30,7 +30,6 @@ ERROR_OUTPUT_VERBOSE = 4 # + print all error data points if failed OUTPUT_VERBOSE = 5 # + print all result including passed tensor INPUT_VERBOSE = 6 # + print input torch tensors -PIPELINE_VERBOSE = 7 # + print ir when running pipeline """ acc | acc | elems | value_range | worst case @@ -333,3 +332,33 @@ def p2p( def nelem(shape: List[int]) -> int: return reduce(operator.mul, shape) + + +def to_int_list(s: str) -> List[int]: + """ + Parsing the cmd for list of int values + + Args: + s (str): int values in cmd, example: 2x3x4 + + Returns: + List[int]: int values in list, example: [2, 3, 4] + """ + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_list(s: str) -> List[bool]: + """ + Parsing the cmd for list of bool values + + Args: + s (str): bools in cmd, example: 1x0x1 + + Returns: + List[bool]: bools in list, example: [True, False, True] + """ + if not s or len(s) == 0: + return [] + return [bool(int(i)) for i in s.strip().split("x")] From 8e85b80fde807bd8fb340cbbc48028567cff3018 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Wed, 28 Aug 2024 18:45:57 -0700 Subject: [PATCH 15/38] fix --- test/benchgc/src/benchgc/__main__.py | 2 +- test/benchgc/src/benchgc/{mlir => }/bench.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename test/benchgc/src/benchgc/{mlir => }/bench.py (100%) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 3f604f0bf..abcf0c706 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -31,8 +31,8 @@ set_default_fill, ) from benchgc.arg.arg import Arg +from benchgc.bench import mlir_wrapper_bench, py_timeit_bench from benchgc.mlir.arg import get_mlir_args -from benchgc.mlir.bench import mlir_wrapper_bench, py_timeit_bench from benchgc.pattern import get_pattern_clz from gc_mlir import ir from gc_mlir.graph_compiler import GraphCompiler diff --git a/test/benchgc/src/benchgc/mlir/bench.py b/test/benchgc/src/benchgc/bench.py similarity index 100% rename from test/benchgc/src/benchgc/mlir/bench.py rename to test/benchgc/src/benchgc/bench.py From 56f2de6f256628214bdeade9df14e66cb68e72e8 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Wed, 28 Aug 2024 19:01:06 -0700 Subject: [PATCH 16/38] fix --- test/benchgc/src/benchgc/mlir/arg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py index 5db494c52..2fb0a9871 100644 --- a/test/benchgc/src/benchgc/mlir/arg.py +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -23,7 +23,6 @@ import gc_mlir.dialects.tensor import torch from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr -from benchgc.util import to_int_list from gc_mlir import ir @@ -96,7 +95,7 @@ def set_md(self, md: str): self.dtype = splited[-1] self.shape = [] - self.shape = to_int_list(splited[:-1]) + self.shape = [int(x) for x in splited[:-1]] self.set_scalar() def set_scalar(self): From b87b2d45ee44bbbc469940fba43d47c3aa368b6e Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Wed, 28 Aug 2024 20:12:54 -0700 Subject: [PATCH 17/38] add readme --- test/benchgc/README.md | 24 +++++++++++++++++++----- test/benchgc/src/benchgc/__main__.py | 3 +-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index 0c16b5a8b..f551941f9 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -9,20 +9,31 @@ Benchgc is a tool used to verify the correctness and performance of graph compil * torch >= 2.2 * Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail -## Build and install +## Build +There are two ways for using benchgc + +* Build `.whl` and install benchgc ``` # Please execute at the top level of the project -mkdir -p build -cd build - +mkdir build && cd build cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON make -j benchgc - python -m pip install test/benchgc/dist/benchgc-*.whl ``` +* Run benchgc from source code + +``` +# Please execute at the top level of the project + +mkdir build && cd build +cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON +make -j GcPythonModules +export PYTHONPATH=$(pwd)/python_packages/gc_mlir_core/:$(pwd)/../test/benchgc/src/ +``` + ## Synopsis ``` python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] @@ -59,6 +70,9 @@ python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] * 5 : OUTPUT_VERBOSE, + print all result including passed tensor * 6 : INPUT_VERBOSE, + print input torch tensors +### --ir_printing (action=store_true) +* Print the ir during the pass-pipeline + ### --md index:SHAPExTYPE * Describe the shape and data type for argument * Not available when driver=mlir diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index abcf0c706..283fe9dc6 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -120,9 +120,8 @@ def add_common_options(parser: argparse.ArgumentParser): parser.add_argument( "--ir_printing", - default=False, + action="store_true", help="if we need print the ir during the pass-pipeline", - type=bool, ) if parser.parse_known_args()[0].driver == "linalg": From b2597b9459d0f9ffb0bed837e5e134136bd7a638 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 19:08:21 -0700 Subject: [PATCH 18/38] add mlp filling --- test/benchgc/src/benchgc/arg/__init__.py | 4 +++ test/benchgc/src/benchgc/pattern/mlp.py | 39 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/test/benchgc/src/benchgc/arg/__init__.py b/test/benchgc/src/benchgc/arg/__init__.py index a2134af2d..312f42753 100644 --- a/test/benchgc/src/benchgc/arg/__init__.py +++ b/test/benchgc/src/benchgc/arg/__init__.py @@ -27,6 +27,7 @@ import benchgc.util import torch from benchgc.arg.arg import Arg +from benchgc.pattern import get_pattern_clz onednn_module = { "binary": binary, @@ -53,6 +54,9 @@ def set_default_fill( if flags.driver + "." + flags.case in module.op: module.default_fill(flags, arg, arglist) return + elif flags.driver == "pattern": + get_pattern_clz(flags.case).default_fill(flags, arg, arglist) + return # use N(0, 1) as default arg.fill_type = "N" arg.fill_param = ["0", "1"] diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index e9bf50d97..bcb69ea71 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -16,11 +16,14 @@ ################################################################################ import argparse +from typing import List +from benchgc.arg.arg import Arg from benchgc.mlir.util import str_to_mlir_dtype from benchgc.util import to_bool_list, to_int_list from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor +from numpy import dtype from .base import Pattern @@ -144,3 +147,39 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ) func.ReturnOp([data]) return module + + def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], + ): + layers = len(flags.hidden_size_list.strip().split("x")) + if arg.index == 0: + # src + arg.fill_type = "D" + arg.fill_param = [ + "matmul", + "src", + arglist[0].dtype, + arglist[0].dtype, + arglist[0].dtype, + 1, + ] + elif arg.index <= layers: + # wei + arg.fill_type = "D" + arg.fill_param = [ + "matmul", + "wei", + arglist[0].dtype, + arglist[0].dtype, + arglist[0].dtype, + 1, + ] + else: + # bias + arg.fill_type = "N" + if arg.dtype in ["f32", "bf16", "f16"]: + arg.fill_param = ["-8", "8"] + else: + arg.fill_param = ["0", "8"] From 4392974f3452e74221154f33675808f7fb8782e7 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 19:33:46 -0700 Subject: [PATCH 19/38] fix mlp --- test/benchgc/src/benchgc/pattern/mlp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/benchgc/src/benchgc/pattern/mlp.py b/test/benchgc/src/benchgc/pattern/mlp.py index bcb69ea71..76be9c066 100644 --- a/test/benchgc/src/benchgc/pattern/mlp.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -23,7 +23,6 @@ from benchgc.util import to_bool_list, to_int_list from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor -from numpy import dtype from .base import Pattern @@ -105,9 +104,13 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ), ) f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): data = f.entry_block.arguments[0] bias_idx = len(weights) + 1 + zero = arith.ConstantOp( + value=ir.FloatAttr.get(dtype, 0.0), result=dtype + ).result for i in range(layers): weight = f.entry_block.arguments[i + 1] if self.has_bias[i]: @@ -119,10 +122,10 @@ def init_module(self, ctx: ir.Context) -> ir.Module: self.batch_size, self.hidden_size_list[i + 1], ] - - data = linalg.matmul( - data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + out = linalg.fill( + zero, outs=[tensor.EmptyOp(layer_out_shape, dtype)] ) + data = linalg.matmul(data, weight, outs=[out]) if bias: broadcast_bias = linalg.broadcast( bias, From 3566b838ded0983d40d33bbcb1648e51d9ddf1a0 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 19:35:49 -0700 Subject: [PATCH 20/38] add case --- scripts/correctness.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/correctness.sh b/scripts/correctness.sh index c0ae008ce..509b24066 100755 --- a/scripts/correctness.sh +++ b/scripts/correctness.sh @@ -102,5 +102,8 @@ python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || F # mlir # python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/llama2.mlir || FAIL=1 +#mlp +python3 -m benchgc --verbose 1 --driver pattern --case mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=1x1 --act_type=noop --dtype=f32 + set +e exit $FAIL \ No newline at end of file From 8deb44ccee03276255dc87a38f175e8b48919d4b Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 19:45:20 -0700 Subject: [PATCH 21/38] remove old bench code --- tools/README.md | 92 -------------- tools/bench.py | 205 ------------------------------- tools/drivers.py | 232 ----------------------------------- tools/example/simple_test.py | 74 ----------- tools/main.py | 102 --------------- tools/utils.py | 191 ---------------------------- tools/workloads/test.mlir | 8 -- 7 files changed, 904 deletions(-) delete mode 100644 tools/README.md delete mode 100644 tools/bench.py delete mode 100644 tools/drivers.py delete mode 100644 tools/example/simple_test.py delete mode 100644 tools/main.py delete mode 100644 tools/utils.py delete mode 100644 tools/workloads/test.mlir diff --git a/tools/README.md b/tools/README.md deleted file mode 100644 index 931a1b906..000000000 --- a/tools/README.md +++ /dev/null @@ -1,92 +0,0 @@ -# Python Tools -## Pre-requisites -### Enable python binding -* Enable MLIR python binding, [README](https://github.com/intel/graph-compiler/blob/main/python/README.md) -### Set env -* **PYTHONPATH**=*${BUILD_DIR}*/python_packages/gc_mlir_core -* **LD_PRELOAD**=path/to/libiomp5.so - - -## Bench -The tool has two different ways to calculate the time cost, and more experiments are needed to test which one is more stable and accurate. Currently, users can choose which way to use through options -* Use the MLIR Python API to invoke the kernel and use Python to calculate the time cost -* Modify MLIR by wrapping the kernel into a new method and calling the `nanoTime()` method before and after calling the kernel. Finally, calculate the difference as the time cost -``` - func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface} - func.func public @wrapped_main(%arg0: memref<1xi64>, %arg1: tensor<128x512xbf16>, %arg2: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { - %0 = call @nanoTime() : () -> i64 - %1 = call @main_entry(%arg1, %arg2) : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16> - %2 = call @nanoTime() : () -> i64 - %3 = arith.subi %2, %0 : i64 - %c0 = arith.constant 0 : index - memref.store %3, %arg0[%c0] : memref<1xi64> - return %1 : tensor<128x256xbf16> - } -} -``` - -### Examples: -``` -# simple version -python3 ./tools/main.py --driver=load_mlir --path=./tools/workloads/test.mlir - -# complex version -python3 ./tools/main.py --type=bench --bench_kind=py --driver=load_mlir --path=./tools/workloads/test.mlir --warm_up=200 --repeat=200 --print_ir --entry=main_entry -``` - -``` -# result example -===========bench result=========== -{ - "args": { - "type": "bench", - "driver": "load_mlir", - "path": "./tools/workloads/test.mlir", - "entry": "main_entry", - "bench_kind": "py", - "print_ir": false, - "warm_up": 20, - "repeat": 100 - }, - "compile_cost(ms)": 25.58841183781624, - "execute_cost(ms)": 1.7501823976635933 -} -``` - -### Common Options -* `--driver`: the pattern to bench, currently support `mlp` and `load_mlir` -* `--bench_kind`: `py` or `wrapper`, different evaluation implementation of the benchmark -* `--warm_up`: warm-up times of the execution -* `--repeat`: repeat times of the execution -* `--print_ir`: print the ir before execution -* `--disable_results_to_params`: do not use this when using the default pipeline (gc-cpu-pipeline) - -### Driver Specific Options -* load_mlir - * `--path`: the mlir file path - * `--entry`: the name of entry func -``` -python3 ./tools/main.py --driver=load_mlir --path=./tools/workloads/test.mlir -``` - - -* mlp - * `--batch_size`: the input - * `--hidden_size_list`: hidden_sizes of mlp, example: 32x16x64 - * `--has_bias`: if the matmul op has bias, example: 1x0 - * `--act_type`: choices=["noop", "relu", "sigmoid"] - * `--dtype`: choices=["bf16", "f32"] -``` -python3 ./tools/main.py --driver=mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=0x0 --act_type=noop --dtype=f32 - -===========bench func name: main_entry =========== -module { - func.func @main_entry(%arg0: tensor<32x32xf32>, %arg1: tensor<32x16xf32>, %arg2: tensor<16x64xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} { - %0 = tensor.empty() : tensor<32x16xf32> - %1 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x16xf32>) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> - %2 = tensor.empty() : tensor<32x64xf32> - %3 = linalg.matmul {cast = #linalg.type_fn} ins(%1, %arg2 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%2 : tensor<32x64xf32>) -> tensor<32x64xf32> - return %3 : tensor<32x64xf32> - } -} -``` \ No newline at end of file diff --git a/tools/bench.py b/tools/bench.py deleted file mode 100644 index 7cf06d409..000000000 --- a/tools/bench.py +++ /dev/null @@ -1,205 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import ctypes -import random -import timeit -from typing import List, Sequence, Tuple - -import numpy as np -from gc_mlir import ir, runtime -from gc_mlir.graph_compiler import GraphCompiler -from utils import ( - emit_benchmark_wrapped_main_func, - emit_nano_time, - get_kernel_func_from_module, -) - - -def py_timeit_bench( - ir_module: ir.Module, - entry_name: str, - pipeline: str, - mlir_args: list, - ir_printing=False, - repeat_time=100, - warm_up=20, -) -> Tuple[float, float]: - """benchmark mlir with python timeit.""" - compiler = GraphCompiler(pipeline) - compile_begin = timeit.default_timer() - engine = compiler.compile_and_jit(ir_module, ir_printing=ir_printing) - compile_cost = (timeit.default_timer() - compile_begin) * 1000 - - # Copied from execution_engine.py so that the cost of cast does not affect perf result. - func = engine.lookup(entry_name) - packed_args = (ctypes.c_void_p * len(mlir_args))() - for argNum in range(len(mlir_args)): - packed_args[argNum] = ctypes.cast(mlir_args[argNum], ctypes.c_void_p) - - def run_bench(func, arg): - func(arg) - - timeit.timeit(lambda: run_bench(func, packed_args), number=warm_up) - total_time = timeit.timeit(lambda: run_bench(func, packed_args), number=repeat_time) - execute_cost = total_time * 1000 / repeat_time - return (execute_cost, compile_cost) - - -def mlir_wrapper_bench( - ir_module: ir.Module, - entry_name: str, - pipeline: str, - mlir_args: list, - ir_printing=False, - repeat_time=100, - warm_up=20, -) -> Tuple[float, float]: - """benchmark mlir with a wrapper func.""" - kernel_func = get_kernel_func_from_module(ir_module, entry_name) - wrapper_module = ir_module - with ir.InsertionPoint(wrapper_module.body): - emit_benchmark_wrapped_main_func(kernel_func, emit_nano_time()) - compiler = GraphCompiler(pipeline) - compile_begin = timeit.default_timer() - engine = compiler.compile_and_jit(wrapper_module, ir_printing=ir_printing) - compile_cost = (timeit.default_timer() - compile_begin) * 1000 - - np_timers_ns = np.array([0], dtype=np.int64) - time_arg = ctypes.pointer( - ctypes.pointer(runtime.get_ranked_memref_descriptor(np_timers_ns)) - ) - total_time = 0 - ns_to_ms_scale = 1e-6 - def run(engine_invoke, bench_func_name, *mlir_args): - engine_invoke(bench_func_name, *mlir_args) - - for i in range(repeat_time + warm_up): - run(engine.invoke, "wrapped_main", time_arg, *mlir_args) - if i >= warm_up: - total_time += int(np_timers_ns[0]) * ns_to_ms_scale - execute_cost = total_time / repeat_time - return (execute_cost, compile_cost) - - -# for test -def fake_bench( - ir_module: ir.Module = None, - entry_name: str = None, - pipeline: str = None, - mlir_args: list = None, - ir_printing=False, - repeat_time=100, - warm_up=20, -) -> Tuple[float, float]: - """genrate fake benchmark result.""" - execute_cost = float(random.randint(1, 100)) - compile_cost = float(random.randint(1, 100)) - return (execute_cost, compile_cost) - - -def batch_py_timeit_bench( - ir_modules: List[ir.Module], - entry_name: str, - pipeline: str, - mlir_args: list, - ir_printing=False, - repeat_time=5, - warm_up=2, -) -> List[Tuple[float, float]]: - """benchmark a batch of mlir with python timeit.""" - compiler = GraphCompiler(pipeline) - funcs = [] - compile_costs = [] - for m in ir_modules: - compile_begin = timeit.default_timer() - engine = compiler.compile_and_jit(m, ir_printing=ir_printing) - compile_cost = (timeit.default_timer() - compile_begin) * 1000 - compile_costs.append(compile_cost) - funcs.append(engine.lookup(entry_name)) - - # Copied from execution_engine.py so that the cost of cast does not affect perf result. - packed_args = (ctypes.c_void_p * len(mlir_args))() - for argNum in range(len(mlir_args)): - packed_args[argNum] = ctypes.cast(mlir_args[argNum], ctypes.c_void_p) - - def run_bench(func, arg): - func(arg) - - for func in funcs: - timeit.timeit(lambda: run_bench(func, packed_args), number=warm_up) - - execute_costs = [] - for func in funcs: - total_time = timeit.timeit( - lambda: run_bench(func, packed_args), number=repeat_time - ) - execute_cost = total_time * 1000 / repeat_time - execute_costs.append(execute_cost) - return list(zip(compile_costs, execute_costs)) - - -def batch_mlir_wrapper_bench( - ir_modules: ir.Module, - entry_name: str, - pipeline: str, - mlir_args: list, - ir_printing=False, - repeat_time=5, - warm_up=2, -) -> Tuple[float, float]: - """benchmark a batch of mlir with wrapper func.""" - compiler = GraphCompiler(pipeline) - - engine_invokes = [] - compile_costs = [] - for m in ir_modules: - kernel_func = get_kernel_func_from_module(m, entry_name) - wrapper_module = m - with ir.InsertionPoint(wrapper_module.body): - emit_benchmark_wrapped_main_func(kernel_func, emit_nano_time()) - compile_begin = timeit.default_timer() - engine = compiler.compile_and_jit(wrapper_module, ir_printing=ir_printing) - compile_cost = (timeit.default_timer() - compile_begin) * 1000 - compile_costs.append(compile_cost) - engine_invokes.append(engine.invoke) - - np_timers_ns = np.array([0], dtype=np.int64) - time_arg = ctypes.pointer( - ctypes.pointer(runtime.get_ranked_memref_descriptor(np_timers_ns)) - ) - total_time = 0 - ns_to_ms_scale = 1e-6 - - def run(engine_invoke, bench_func_name, *mlir_args): - engine_invoke(bench_func_name, *mlir_args) - - for engine_invoke in engine_invokes: - for _ in range(warm_up): - run(engine_invoke, "wrapped_main", time_arg, *mlir_args) - - execute_costs = [] - for engine_invoke in engine_invokes: - total_time = 0 - for _ in range(repeat_time): - run(engine_invoke, "wrapped_main", time_arg, *mlir_args) - total_time += int(np_timers_ns[0]) * ns_to_ms_scale - - execute_cost = total_time / repeat_time - execute_costs.append(execute_cost) - - return list(zip(compile_costs, execute_costs)) diff --git a/tools/drivers.py b/tools/drivers.py deleted file mode 100644 index a9bdc95d0..000000000 --- a/tools/drivers.py +++ /dev/null @@ -1,232 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import argparse -from abc import ABC, abstractmethod -from typing import List - -import numpy as np -from gc_mlir import ir -from gc_mlir.dialects import arith, func, linalg, tensor -from gc_mlir.ir import BF16Type, FloatAttr -from utils import ( - STR_TO_MLIR_TYPE, - get_default_passes, - get_kernel_func_from_module, - make_mlir_ndarray, - to_bool_list, - to_int_list, -) - - -class Driver(ABC): - """Abstract class for driver.""" - - @staticmethod - @abstractmethod - def add_args(parser: argparse.ArgumentParser): - """Add arguments to parser""" - pass - - @abstractmethod - def handle_args(self, args: argparse.Namespace): - """Get and handle the args""" - pass - - def __init__(self, ctx: ir.Context, args: argparse.Namespace): - self.main_entry = "main_entry" - self.handle_args(args) - self.ir_module = self.init_module(ctx) - - @abstractmethod - def init_module(self, ctx: ir.Context) -> ir.Module: - """Create MLIR moudule by args""" - pass - - @abstractmethod - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - """Create numpy arg for entry function""" - pass - - def get_passes(self) -> str: - """Get pass pipeline""" - return get_default_passes() - - -class LoadMLIR(Driver): - @staticmethod - def add_args(parser: argparse.ArgumentParser): - parser.add_argument("--path", type=str, required=True) - parser.add_argument("--entry", type=str, default="main_entry") - - def handle_args(self, args: argparse.Namespace): - self.path = args.path - self.main_entry = args.entry - - def _get_mlir(self): - with open(self.path, "r") as file: - content = file.read() - return content - - def init_module(self, ctx: ir.Context) -> ir.Module: - module = ir.Module.parse(self._get_mlir(), ctx) - bench_func = get_kernel_func_from_module(module, self.main_entry) - bench_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - return module - - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) - np_args = [] - for arg in bench_func.arguments: - np_args.append(make_mlir_ndarray(arg.type)) - - if not disable_results_to_params: - for res in bench_func.type.results: - np_args.append(make_mlir_ndarray(res)) - - return np_args - -class MLP(Driver): - @staticmethod - def add_args(parser: argparse.ArgumentParser): - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--hidden_size_list", type=str, default="") - parser.add_argument("--has_bias", required=False, type=str) - parser.add_argument( - "--act_type", type=str, choices=["noop", "relu", "sigmoid"], default="noop" - ) - parser.add_argument( - "--dtype", - type=str, - choices=[ - "f32", - "bf16", - ], - default="f32", - ) - - def handle_args(self, args: argparse.Namespace): - self.batch_size = args.batch_size - assert self.batch_size > 0, "batch size should be greater than 0" - - self.hidden_size_list = to_int_list(args.hidden_size_list) - layers = len(self.hidden_size_list) - 1 - assert layers >= 1, "hidden_size_list should have at least 2 elements" - - self.has_bias = ( - [False] * layers if args.has_bias is None else to_bool_list(args.has_bias) - ) - - assert ( - len(self.has_bias) == layers - ), "has_bias should have the same length as hidden_size_list" - - self.act_type = args.act_type - self.dtype = args.dtype - - def init_module(self, ctx: ir.Context) -> ir.Module: - with ctx, ir.Location.unknown(): - layers = len(self.hidden_size_list) - 1 - module = ir.Module.create() - dtype = STR_TO_MLIR_TYPE(self.dtype, ctx) - src = ir.RankedTensorType.get( - [self.batch_size, self.hidden_size_list[0]], dtype - ) - weights = [] - bias = [] - for i in range(layers): - weights.append( - ir.RankedTensorType.get( - [ - self.hidden_size_list[i], - self.hidden_size_list[i + 1], - ], - dtype, - ) - ) - if self.has_bias[i]: - bias.append( - ir.RankedTensorType.get([self.hidden_size_list[i + 1]], dtype) - ) - result = ir.RankedTensorType.get( - [ - self.batch_size, - self.hidden_size_list[-1], - ], - dtype, - ) - with ir.InsertionPoint(module.body): - f = func.FuncOp( - name=self.main_entry, - type=ir.FunctionType.get( - inputs=[src] + weights + bias, results=[result] - ), - ) - f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(f.add_entry_block()): - data = f.entry_block.arguments[0] - bias_idx = len(weights) + 1 - for i in range(layers): - weight = f.entry_block.arguments[i + 1] - if self.has_bias[i]: - bias = f.entry_block.arguments[bias_idx] - bias_idx += 1 - else: - bias = None - layer_out_shape = [ - self.batch_size, - self.hidden_size_list[i + 1], - ] - - data = linalg.matmul( - data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] - ) - if bias: - broadcast_bias = linalg.broadcast( - bias, - outs=[tensor.EmptyOp(layer_out_shape, dtype)], - dimensions=[0], - ) - data = linalg.add( - data, - broadcast_bias, - outs=[tensor.EmptyOp(layer_out_shape, dtype)], - ) - - if self.act_type == "relu": - element = FloatAttr.get(dtype, 0) - tensor_type = ir.RankedTensorType.get( - layer_out_shape, dtype - ) - attr = ir.DenseElementsAttr.get_splat(tensor_type, element) - cst = arith.ConstantOp(tensor_type, attr) - data = linalg.max( - data, cst, outs=[tensor.EmptyOp(layer_out_shape, dtype)] - ) - func.ReturnOp([data]) - return module - - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) - np_args = [] - for arg in bench_func.arguments: - np_args.append(make_mlir_ndarray(arg.type)) - - if not disable_results_to_params: - for res in bench_func.type.results: - np_args.append(make_mlir_ndarray(res)) - return np_args diff --git a/tools/example/simple_test.py b/tools/example/simple_test.py deleted file mode 100644 index 81baa9085..000000000 --- a/tools/example/simple_test.py +++ /dev/null @@ -1,74 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import os -import sys - -import numpy as np -from gc_mlir import ir -from gc_mlir.graph_compiler import GraphCompiler -from numpy.testing import assert_allclose - -project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if project_dir not in sys.path: - sys.path.insert(0, project_dir) - -import ml_dtypes -import torch -from utils import get_mlir_args - -# an example of simple validation -if __name__ == "__main__": - with ir.Context() as ctx: - ctx.enable_multithreading(False) - module = ir.Module.parse( - """ - module { - func.func @main_entry(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<10x10xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<10x10xbf16>) -> tensor<10x10xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<10x10xbf16>, tensor<10x10xbf16>) outs(%1 : tensor<10x10xbf16>) -> tensor<10x10xbf16> - return %2 : tensor<10x10xbf16> - } - } - """ - ) - torch_arg0 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) - torch_arg1 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) - ref_res = torch.matmul(torch_arg0, torch_arg1) - - np_arg0 = torch_arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - np_arg1 = torch_arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - gc_res = np.zeros((10, 10), dtype=ml_dtypes.bfloat16) - - entry = "main_entry" - mlir_args = get_mlir_args(module, entry, [np_arg0, np_arg1, gc_res]) - passes = "any(gc-cpu-pipeline)" - - # just run - compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module, ir_printing=True) - engine.invoke(entry, *mlir_args) - - print(gc_res) - assert_allclose( - gc_res.astype(np.float32), - ref_res.to(torch.float32).numpy(), - rtol=1e-5, - atol=0, - ) diff --git a/tools/main.py b/tools/main.py deleted file mode 100644 index 9c08c6913..000000000 --- a/tools/main.py +++ /dev/null @@ -1,102 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import argparse -import json -import numpy as np -from bench import ( - mlir_wrapper_bench, - py_timeit_bench, -) -from drivers import MLP, LoadMLIR -from gc_mlir import ir -from utils import get_mlir_args - - -def get_driver_clz(diver_str: str): - """Function getting driver class by name.""" - clz = {"mlp": MLP, "load_mlir": LoadMLIR}[diver_str] - return clz - - -def add_driver_args(arg_parser: argparse.ArgumentParser): - """Function adding args for different driver.""" - driver = arg_parser.parse_known_args()[0].driver - get_driver_clz(driver).add_args(arg_parser) - - -def do_bench(args: argparse.Namespace): - """Function benching mlir""" - with ir.Context() as ctx, ir.Location.unknown(): - driver_clz = get_driver_clz(args.driver) - driver = driver_clz(ctx, args) - if args.print_ir: - ctx.enable_multithreading(False) - np_args = driver.prepare_np_args(args.disable_results_to_params) - - # TODO need data filling - # for test, fill all data with 1 - for np_arg in np_args: - np.ndarray.fill(np_arg, 1) - - mlir_args = get_mlir_args( - driver.ir_module, driver.main_entry, np_args, args.disable_results_to_params - ) - - print("===========bench func name: ", driver.main_entry, "===========") - print(driver.ir_module) - bench_kind = py_timeit_bench if args.bench_kind == "py" else mlir_wrapper_bench - execute_cost, compile_cost = bench_kind( - driver.ir_module, - driver.main_entry, - driver.get_passes(), - mlir_args, - args.print_ir, - args.repeat, - args.warm_up, - ) - print("===========bench result===========") - json_res = json.dumps( - { - "args": vars(args), - "compile_cost(ms)": compile_cost, - "execute_cost(ms)": execute_cost, - }, - indent=4, - ) - print(json_res) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--type", type=str, choices=["bench"], default="bench") - parser.add_argument( - "--driver", type=str, choices=["load_mlir", "mlp"], required=True - ) - add_driver_args(parser) - parser.add_argument( - "--bench_kind", type=str, choices=["py", "wrapper"], default="py" - ) - parser.add_argument("-p", "--print_ir", action="store_true") - parser.add_argument( - "--disable_results_to_params", action="store_true", default=False - ) - - parser.add_argument("--warm_up", type=int, default=100) - parser.add_argument("--repeat", type=int, default=100) - - do_bench(parser.parse_args()) diff --git a/tools/utils.py b/tools/utils.py deleted file mode 100644 index fc01fd208..000000000 --- a/tools/utils.py +++ /dev/null @@ -1,191 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import ctypes -from typing import List - -import ml_dtypes -import numpy as np -from gc_mlir import ir -from gc_mlir.dialects import arith, func, memref -from gc_mlir.runtime.np_to_memref import ( - BF16, - get_ranked_memref_descriptor, - make_nd_memref_descriptor, -) - -MLIR_TYPE_TO_NUMPY_TYPE = { - "bf16": ml_dtypes.bfloat16, - "f32": np.float32, - "f64": np.float64, - "i8": np.int8, - "i32": np.int32, - "i64": np.int64, -} - -MLIR_TYPE_TO_C_TYPE = { - "f32": ctypes.c_float, - "f64": ctypes.c_double, - "i32": ctypes.c_int, - "i8": ctypes.c_byte, - "bf16": BF16, -} - - -def STR_TO_MLIR_TYPE(type: str, ctx: ir.Context): - type_map = { - "f32": ir.F32Type.get(ctx), - "f64": ir.F64Type.get(ctx), - "bf16": ir.BF16Type.get(ctx), - "i32": ir.IntegerType.get_signed(32, ctx), - "i8": ir.IntegerType.get_signed(8, ctx), - } - return type_map[type] - - -def emit_nano_time() -> func.FuncOp: - """Emit a nanoTime function that returns the current time in nanoseconds.""" - nanoTime = func.FuncOp( - "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" - ) - nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - return nanoTime - - -def emit_benchmark_wrapped_main_func( - kernel_func: func.FuncOp, timer_func: func.FuncOp -) -> func.FuncOp: - """Emit a wrapped main function that calls the kernel function and records the time taken.""" - memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) - wrapped_func_name = "wrapped_main" - assert wrapped_func_name != str( - kernel_func.name - ), "wrapped function name should be different from kernel function name" - wrapped_func = func.FuncOp( - wrapped_func_name, - ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), - visibility="public", - ) - wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(wrapped_func.add_entry_block()): - timer_buffer = wrapped_func.arguments[0] - start = func.CallOp(timer_func, []) - call_op = func.CallOp( - kernel_func, - list(wrapped_func.arguments[1:]), - ) - end = func.CallOp(timer_func, []) - time_taken = arith.SubIOp(end, start) - zero = arith.ConstantOp.create_index(0) - memref.StoreOp(time_taken, timer_buffer, [zero]) - func.ReturnOp(call_op.results) - return wrapped_func - - -def get_mlir_args( - module: ir.Module, - entry: str, - np_args: List[np.ndarray], - disable_results_to_params=False, -): - """Convert numpy arrays to MLIR args and return a list of pointers to them""" - f = get_kernel_func_from_module(module, entry) - compiled_func_args = [] - if disable_results_to_params: - assert len(np_args) == len(f.arguments), "input args mismatch" - for res in f.type.results: - compiled_func_args.append( - ctypes.pointer( - ctypes.pointer( - make_nd_memref_descriptor( - len(res.shape), MLIR_TYPE_TO_C_TYPE[str(res.element_type)] - )() - ) - ) - ) - for arg in np_args: - compiled_func_args.append( - ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg))) - ) - return compiled_func_args - - -def make_mlir_ndarray(mlir_type): - """create numpy ndarray from mlir type""" - return np.zeros( - mlir_type.shape, MLIR_TYPE_TO_NUMPY_TYPE[str(mlir_type.element_type)] - ) - - -def get_kernel_func_from_module( - module: ir.Module, func_name: str = "main_entry" -) -> func.FuncOp: - """Get the func op by the name from a module""" - assert ( - len(module.operation.regions) == 1 - ), "Expected kernel module to have only one region" - assert ( - len(module.operation.regions[0].blocks) == 1 - ), "Expected kernel module to have only one block" - for f in module.operation.regions[0].blocks[0].operations: - if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: - return f - raise ValueError("can not find the entry function") - - -def get_default_passes(): - passes = """ - any(gc-cpu-pipeline) - """ - return passes - - -def to_int_list(s: str) -> List[int]: - """ - Parsing the cmd for list of int values - - Args: - s (str): int values in cmd, example: 2x3x4 - - Returns: - List[int]: int values in list, example: [2, 3, 4] - """ - if not s or len(s) == 0: - return [] - return [int(i) for i in s.strip().split("x")] - - -def to_bool_list(s: str) -> List[bool]: - """ - Parsing the cmd for list of bool values - - Args: - s (str): bools in cmd, example: 1x0x1 - - Returns: - List[bool]: bools in list, example: [True, False, True] - """ - if not s or len(s) == 0: - return [] - return [bool(int(i)) for i in s.strip().split("x")] - - -def load_mlir_from_path(path: str) -> str: - """Load MLIR content from path""" - with open(path, "r") as file: - content = file.read() - return content diff --git a/tools/workloads/test.mlir b/tools/workloads/test.mlir deleted file mode 100644 index 9170ec61a..000000000 --- a/tools/workloads/test.mlir +++ /dev/null @@ -1,8 +0,0 @@ - -func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x256xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %2 : tensor<128x256xbf16> -} \ No newline at end of file From a0641e9f75d6ddbf2cf565a2cff5f3e81be0de37 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 20:08:11 -0700 Subject: [PATCH 22/38] update readme --- test/benchgc/README.md | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index f551941f9..9f18cc398 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -337,6 +337,7 @@ module { "seed": 0, "verbose": 1, "entry": "entry", + "ir_printing": false, "cast": "cast_signed", "dimension": null, "dimensions": null, @@ -346,21 +347,21 @@ module { "warm_up": 100, "repeat": 100 }, - "compile_cost(ms)": 33.73148664832115, - "execute_cost(ms)": 0.1422157883644104 + "compile_cost(ms)": 37.72595152258873, + "execute_cost(ms)": 0.00022314488887786865 } ``` * mlir example ``` python3 -m benchgc --mode P --verbose 1 --driver mlir --case=./test.mlir --bench_kind wrapper --warm_up 50 --repeat 200 -module { - func.func @entry(%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> attributes {llvm.emit_c_interface} { +\module { + func.func @entry(%arg0: tensor<512x128xf32>) -> tensor<512x128xf32> attributes {llvm.emit_c_interface} { %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<5x6xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32> - %2 = linalg.abs ins(%arg0 : tensor<5x6xf32>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32> - return %2 : tensor<5x6xf32> + %0 = tensor.empty() : tensor<512x128xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x128xf32>) -> tensor<512x128xf32> + %2 = linalg.abs ins(%arg0 : tensor<512x128xf32>) outs(%1 : tensor<512x128xf32>) -> tensor<512x128xf32> + return %2 : tensor<512x128xf32> } } @@ -381,8 +382,8 @@ module { "warm_up": 50, "repeat": 200 }, - "compile_cost(ms)": 38.10911998152733, - "execute_cost(ms)": 0.077024335 + "compile_cost(ms)": 70.6995539367199, + "execute_cost(ms)": 0.029325044999999984 } ``` * mlp example @@ -421,8 +422,8 @@ module { "act_type": "noop", "dtype": "f32" }, - "compile_cost(ms)": 69.51220706105232, - "execute_cost(ms)": 0.43220914900302887 + "compile_cost(ms)": 109.86808314919472, + "execute_cost(ms)": 0.02944003790616989 } ``` \ No newline at end of file From 2448b7648b086b8a350577175a96b06d180a5314 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 22:55:18 -0700 Subject: [PATCH 23/38] add attch dlti --- test/benchgc/src/benchgc/__main__.py | 2 ++ test/benchgc/src/benchgc/mlir/util.py | 38 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 283fe9dc6..ed101f46b 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -264,6 +264,8 @@ def get_module_and_args(flags): for arg in args: arg.print_verbose(flags.verbose) + benchgc.mlir.util.attch_dlti(module) + if flags.verbose >= benchgc.util.MODULE_VERBOSE: print(module) return module, args diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index faaf4daef..287160243 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -15,8 +15,10 @@ ################################################################################ import ctypes +import os from typing import Any, List +import cpuinfo import torch from gc_mlir import ir from gc_mlir.dialects import arith, func, memref @@ -152,3 +154,39 @@ def get_kernel_func_from_module( if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: return f raise ValueError("can not find the entry function") + + +def attch_dlti(module: ir.Module): + info = cpuinfo.get_cpu_info() + l1_data_cache_size = info.get("l1_data_cache_size") + l2_cache_size = info.get("l2_cache_size") + l3_cache_size = info.get("l3_cache_size") + if "GC_NUM_THREADS" not in os.environ: + print("GC_NUM_THREADS is not found, using 1 as default") + num_threads = os.environ.get("GC_NUM_THREADS", 1) + flags = info.get("flags") + max_vector_width = 64 + for flag in flags: + if "avx512f" == flag: + max_vector_width = max(512, max_vector_width) + elif "avx2" == flag or "avx" == flag: + max_vector_width = max(256, max_vector_width) + elif "sse" in flag: + max_vector_width = max(128, max_vector_width) + + dlti_template = f""" + module attributes {{ + dlti.target_system_spec = #dlti.target_system_spec< + "CPU": #dlti.target_device_spec< + #dlti.dl_entry<"L1_cache_size_in_bytes", {l1_data_cache_size} : ui32>, + #dlti.dl_entry<"L2_cache_size_in_bytes", {l2_cache_size} : ui64>, + #dlti.dl_entry<"L3_cache_size_in_bytes", {l3_cache_size} : ui64>, + #dlti.dl_entry<"num_threads", {num_threads} : i32>, + #dlti.dl_entry<"max_vector_width", {max_vector_width} : i64>> + >}} {{}} + """ + with module.context: + template_module = ir.Module.parse(dlti_template) + module.operation.attributes["dlti.target_system_spec"] = ( + template_module.operation.attributes["dlti.target_system_spec"] + ) \ No newline at end of file From d614c4006c6b69e20903f3a43bd8a91509d6fc57 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 23:00:55 -0700 Subject: [PATCH 24/38] skip attach dlti when it was already added --- test/benchgc/src/benchgc/mlir/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 287160243..89d995290 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -157,6 +157,8 @@ def get_kernel_func_from_module( def attch_dlti(module: ir.Module): + if module.operation.attributes["dlti.target_system_spec"] is not None: + return info = cpuinfo.get_cpu_info() l1_data_cache_size = info.get("l1_data_cache_size") l2_cache_size = info.get("l2_cache_size") From 9db3237c84062b14db84dbf0806cc52d26c123d8 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 23:05:22 -0700 Subject: [PATCH 25/38] update readme --- test/benchgc/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index 9f18cc398..bc31bf970 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -8,6 +8,8 @@ Benchgc is a tool used to verify the correctness and performance of graph compil * python >= 3.10 * torch >= 2.2 * Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail +* Set the envs + * GC_NUM_THREADS [int] : the `num_threads` for dlti attr, default = 1 ## Build There are two ways for using benchgc From d714cf06d94522f151b68202a7f59e2bf74e9a15 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 23:24:47 -0700 Subject: [PATCH 26/38] update readme --- test/benchgc/README.md | 1 + test/benchgc/setup.py | 2 +- test/benchgc/src/benchgc/mlir/util.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index bc31bf970..cd1e5d3eb 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -7,6 +7,7 @@ Benchgc is a tool used to verify the correctness and performance of graph compil ## Prerequisite * python >= 3.10 * torch >= 2.2 +* pip install py-cpuinfo * Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail * Set the envs * GC_NUM_THREADS [int] : the `num_threads` for dlti attr, default = 1 diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py index 3d67af539..1d3e67384 100644 --- a/test/benchgc/setup.py +++ b/test/benchgc/setup.py @@ -26,5 +26,5 @@ packages=setuptools.find_packages("src") + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), package_data={"gc_mlir": ["_mlir_libs/*.so"]}, - install_requires=["torch", "numpy", "ml_dtypes"], + install_requires=["torch", "numpy", "py-cpuinfo"], ) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 89d995290..fb41e0f3c 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -157,7 +157,7 @@ def get_kernel_func_from_module( def attch_dlti(module: ir.Module): - if module.operation.attributes["dlti.target_system_spec"] is not None: + if "dlti.target_system_spec" in module.operation.attributes: return info = cpuinfo.get_cpu_info() l1_data_cache_size = info.get("l1_data_cache_size") From 6f34f0f5de0cc69a6c2ba50c12cf0b682bfac2f0 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 23:38:57 -0700 Subject: [PATCH 27/38] fix ci --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0ca233a76..6050c2fc0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,6 +49,7 @@ jobs: - name: Build and install benchgc working-directory: build run: | + pip install py-cpuinfo ninja benchgc pip uninstall -y benchgc || true pip install test/benchgc/dist/benchgc-*.whl From 8cb976fcd413c75ab1d43207447a94dcee94930c Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 1 Sep 2024 23:45:50 -0700 Subject: [PATCH 28/38] fix --- .github/workflows/build.yml | 1 - test/benchgc/setup.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6050c2fc0..0ca233a76 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,7 +49,6 @@ jobs: - name: Build and install benchgc working-directory: build run: | - pip install py-cpuinfo ninja benchgc pip uninstall -y benchgc || true pip install test/benchgc/dist/benchgc-*.whl diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py index 1d3e67384..dd3cc37fd 100644 --- a/test/benchgc/setup.py +++ b/test/benchgc/setup.py @@ -26,5 +26,5 @@ packages=setuptools.find_packages("src") + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), package_data={"gc_mlir": ["_mlir_libs/*.so"]}, - install_requires=["torch", "numpy", "py-cpuinfo"], + install_requires=["torch", "numpy", "cpuinfo"], ) From 15d14f4ea4a49213878ce3d0704f7017c941da63 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 00:10:53 -0700 Subject: [PATCH 29/38] fix --- test/benchgc/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py index dd3cc37fd..1d3e67384 100644 --- a/test/benchgc/setup.py +++ b/test/benchgc/setup.py @@ -26,5 +26,5 @@ packages=setuptools.find_packages("src") + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), package_data={"gc_mlir": ["_mlir_libs/*.so"]}, - install_requires=["torch", "numpy", "cpuinfo"], + install_requires=["torch", "numpy", "py-cpuinfo"], ) From 82521bb93db079065930d80cb0abfc960407c3e9 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 17:45:48 -0700 Subject: [PATCH 30/38] fix env name --- test/benchgc/README.md | 2 +- test/benchgc/src/benchgc/mlir/util.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/benchgc/README.md b/test/benchgc/README.md index cd1e5d3eb..98517638d 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -10,7 +10,7 @@ Benchgc is a tool used to verify the correctness and performance of graph compil * pip install py-cpuinfo * Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail * Set the envs - * GC_NUM_THREADS [int] : the `num_threads` for dlti attr, default = 1 + * OMP_NUM_THREADS [int] : the `num_threads` for dlti attr, default = 1 ## Build There are two ways for using benchgc diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index fb41e0f3c..6b5b92f84 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -163,9 +163,9 @@ def attch_dlti(module: ir.Module): l1_data_cache_size = info.get("l1_data_cache_size") l2_cache_size = info.get("l2_cache_size") l3_cache_size = info.get("l3_cache_size") - if "GC_NUM_THREADS" not in os.environ: - print("GC_NUM_THREADS is not found, using 1 as default") - num_threads = os.environ.get("GC_NUM_THREADS", 1) + if "OMP_NUM_THREADS" not in os.environ: + print("OMP_NUM_THREADS is not found, using 1 as default") + num_threads = os.environ.get("OMP_NUM_THREADS", 1) flags = info.get("flags") max_vector_width = 64 for flag in flags: From fa5961149e86b1268b5acf8edb5f3260f683b0c7 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 17:56:24 -0700 Subject: [PATCH 31/38] add test print --- test/benchgc/src/benchgc/mlir/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 6b5b92f84..9649a33df 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -160,6 +160,7 @@ def attch_dlti(module: ir.Module): if "dlti.target_system_spec" in module.operation.attributes: return info = cpuinfo.get_cpu_info() + print(info) l1_data_cache_size = info.get("l1_data_cache_size") l2_cache_size = info.get("l2_cache_size") l3_cache_size = info.get("l3_cache_size") @@ -187,6 +188,7 @@ def attch_dlti(module: ir.Module): #dlti.dl_entry<"max_vector_width", {max_vector_width} : i64>> >}} {{}} """ + print(dlti_template) with module.context: template_module = ir.Module.parse(dlti_template) module.operation.attributes["dlti.target_system_spec"] = ( From b4aecffdc293b40d52002793c65d7305f6814782 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 18:26:22 -0700 Subject: [PATCH 32/38] add test print2 --- .github/workflows/build.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0ca233a76..dd4e8273e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,6 +26,12 @@ jobs: run: | echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV + + - name: cpu-info + run: | + lscpu + cat /sys/devices/system/cpu/cpu0/cache/index0/size + - name: Fetch requirements for python binding uses: actions/checkout@v4 with: From 20a0e28b90201412e4aeb0058a9eb7b78d20f7ab Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 19:48:52 -0700 Subject: [PATCH 33/38] test ci --- python/MainModule.cpp | 38 +++++++++++++++++++++++++++ test/benchgc/src/benchgc/mlir/util.py | 2 ++ 2 files changed, 40 insertions(+) diff --git a/python/MainModule.cpp b/python/MainModule.cpp index 10f000bc0..6a9df3c26 100644 --- a/python/MainModule.cpp +++ b/python/MainModule.cpp @@ -21,6 +21,41 @@ #include "gc-c/Passes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" + +#include +#include + +void get_l1_data_cache_size() { + uint32_t eax, ebx, ecx, edx; + + // Query the cache information using CPUID with EAX=4 and ECX=1 (L1 data cache) + eax = 4; // Cache information + ecx = 0; // Cache level (0 for L1 data cache) + + __asm__ __volatile__( + "cpuid" + : "=a" (eax), "=b" (ebx), "=c" (ecx), "=d" (edx) + : "a" (eax), "c" (ecx) + ); + + // Extract cache size information + uint32_t cache_type = eax & 0x1F; + if (cache_type != 1) { // 1 indicates data cache + printf("No L1 data cache\n"); + return; + } + + uint32_t cache_level = (eax >> 5) & 0x7; + uint32_t cache_sets = ecx + 1; + uint32_t cache_coherency_line_size = (ebx & 0xFFF) + 1; + uint32_t cache_partitions = ((ebx >> 12) & 0x3FF) + 1; + uint32_t cache_ways_of_associativity = ((ebx >> 22) & 0x3FF) + 1; + + uint32_t cache_size = cache_ways_of_associativity * cache_partitions * cache_coherency_line_size * cache_sets; + + printf("L%d Data Cache Size: %u KB\n", cache_level, cache_size / 1024); +} + PYBIND11_MODULE(_gc_mlir, m) { m.doc() = "Graph-compiler MLIR Python binding"; @@ -57,4 +92,7 @@ PYBIND11_MODULE(_gc_mlir, m) { } }, py::arg("context") = py::none(), py::arg("load") = true); + + + cpuruntimeM.def("get_l1_data_cache_size", &get_l1_data_cache_size, "---"); } \ No newline at end of file diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 9649a33df..9a567958f 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -160,6 +160,8 @@ def attch_dlti(module: ir.Module): if "dlti.target_system_spec" in module.operation.attributes: return info = cpuinfo.get_cpu_info() + from gc_mlir.dialects import cpuruntime + cpuruntime.get_l1_data_cache_size() print(info) l1_data_cache_size = info.get("l1_data_cache_size") l2_cache_size = info.get("l2_cache_size") From 67d155c1b4639120a97b25f278818de7f1464959 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 2 Sep 2024 22:41:10 -0700 Subject: [PATCH 34/38] test ci cpu --- python/MainModule.cpp | 74 +++++++++++++++------------ test/benchgc/src/benchgc/mlir/util.py | 2 +- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/python/MainModule.cpp b/python/MainModule.cpp index 6a9df3c26..aea243da5 100644 --- a/python/MainModule.cpp +++ b/python/MainModule.cpp @@ -21,39 +21,48 @@ #include "gc-c/Passes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" - -#include #include - -void get_l1_data_cache_size() { - uint32_t eax, ebx, ecx, edx; - - // Query the cache information using CPUID with EAX=4 and ECX=1 (L1 data cache) - eax = 4; // Cache information - ecx = 0; // Cache level (0 for L1 data cache) - - __asm__ __volatile__( - "cpuid" - : "=a" (eax), "=b" (ebx), "=c" (ecx), "=d" (edx) - : "a" (eax), "c" (ecx) - ); - - // Extract cache size information - uint32_t cache_type = eax & 0x1F; - if (cache_type != 1) { // 1 indicates data cache - printf("No L1 data cache\n"); - return; +#include + +#include + +// 使用GCC内联汇编的CPUID函数 +void cpuid(int info[4], int InfoType, int ECXValue) { + __asm__ __volatile__("cpuid" + : "=a"(info[0]), "=b"(info[1]), "=c"(info[2]), + "=d"(info[3]) + : "a"(InfoType), "c"(ECXValue)); +} + +void get_cpu_info() { + int info[4]; + cpuid(info, 0, 0); // 获取最大的CPUID功能号 + int nIds = info[0]; + + for (int i = 0; i <= nIds; ++i) { + cpuid(info, 4, i); // 查询缓存参数 + int cacheType = info[0] & 0x1F; + if (cacheType == 0) { + break; // 没有更多的缓存级别 } - - uint32_t cache_level = (eax >> 5) & 0x7; - uint32_t cache_sets = ecx + 1; - uint32_t cache_coherency_line_size = (ebx & 0xFFF) + 1; - uint32_t cache_partitions = ((ebx >> 12) & 0x3FF) + 1; - uint32_t cache_ways_of_associativity = ((ebx >> 22) & 0x3FF) + 1; - - uint32_t cache_size = cache_ways_of_associativity * cache_partitions * cache_coherency_line_size * cache_sets; - - printf("L%d Data Cache Size: %u KB\n", cache_level, cache_size / 1024); + int cacheLevel = (info[0] >> 5) & 0x7; + int cacheLinesPerTag = ((info[1] >> 0) & 0xFFF) + 1; + int cacheAssociativity = ((info[1] >> 12) & 0x3FF) + 1; + int cachePartitions = ((info[1] >> 22) & 0x3FF) + 1; + int cacheSets = info[2] + 1; + int cacheSize = + cacheLinesPerTag * cacheAssociativity * cachePartitions * cacheSets; + + std::cout << "L" << cacheLevel << " "; + if (cacheType == 1) { + std::cout << "Data Cache: "; + } else if (cacheType == 2) { + std::cout << "Instruction Cache: "; + } else if (cacheType == 3) { + std::cout << "Unified Cache: "; + } + std::cout << cacheSize << " bytes" << std::endl; + } } PYBIND11_MODULE(_gc_mlir, m) { @@ -93,6 +102,5 @@ PYBIND11_MODULE(_gc_mlir, m) { }, py::arg("context") = py::none(), py::arg("load") = true); - - cpuruntimeM.def("get_l1_data_cache_size", &get_l1_data_cache_size, "---"); + cpuruntimeM.def("get_cpu_info", &get_cpu_info, "---"); } \ No newline at end of file diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 9a567958f..c8d10cc93 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -161,7 +161,7 @@ def attch_dlti(module: ir.Module): return info = cpuinfo.get_cpu_info() from gc_mlir.dialects import cpuruntime - cpuruntime.get_l1_data_cache_size() + cpuruntime.get_cpu_info() print(info) l1_data_cache_size = info.get("l1_data_cache_size") l2_cache_size = info.get("l2_cache_size") From 1e7c8355fe61751c5d87ecc9bc7c41f12a1c6e79 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Thu, 5 Sep 2024 00:52:49 -0700 Subject: [PATCH 35/38] add new cpuinfo --- .github/workflows/build.yml | 9 +----- python/CMakeLists.txt | 9 ++++++ python/MainModule.cpp | 46 --------------------------- test/benchgc/setup.py | 2 +- test/benchgc/src/benchgc/__main__.py | 16 +++++++++- test/benchgc/src/benchgc/mlir/util.py | 42 +++++++++++++----------- 6 files changed, 49 insertions(+), 75 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dd4e8273e..4b8c60d31 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,14 +24,7 @@ jobs: - name: Set LLVM hash run: | - echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV - - - - name: cpu-info - run: | - lscpu - cat /sys/devices/system/cpu/cpu0/cache/index0/size - + echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV - name: Fetch requirements for python binding uses: actions/checkout@v4 with: diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 505062ecd..355aba91f 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -49,6 +49,8 @@ declare_mlir_python_sources(GcPythonSources.Common __init__.py graph_compiler.py dialects/__init__.py + tools/__init__.py + tools/cpuinfo.py # init hooks _mlir_libs/_site_initialize_0.py ) @@ -86,6 +88,13 @@ declare_mlir_python_extension(GcPythonSources.Extension GcCAPI ) +declare_mlir_python_extension(GcPythonSources.CpuInfoExtension + MODULE_NAME _cpuinfo + ADD_TO_PARENT GcPythonSources + SOURCES + CPUInfo.cpp +) + ################################################################################ # Common CAPI ################################################################################ diff --git a/python/MainModule.cpp b/python/MainModule.cpp index aea243da5..10f000bc0 100644 --- a/python/MainModule.cpp +++ b/python/MainModule.cpp @@ -21,50 +21,6 @@ #include "gc-c/Passes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include -#include - -#include - -// 使用GCC内联汇编的CPUID函数 -void cpuid(int info[4], int InfoType, int ECXValue) { - __asm__ __volatile__("cpuid" - : "=a"(info[0]), "=b"(info[1]), "=c"(info[2]), - "=d"(info[3]) - : "a"(InfoType), "c"(ECXValue)); -} - -void get_cpu_info() { - int info[4]; - cpuid(info, 0, 0); // 获取最大的CPUID功能号 - int nIds = info[0]; - - for (int i = 0; i <= nIds; ++i) { - cpuid(info, 4, i); // 查询缓存参数 - int cacheType = info[0] & 0x1F; - if (cacheType == 0) { - break; // 没有更多的缓存级别 - } - int cacheLevel = (info[0] >> 5) & 0x7; - int cacheLinesPerTag = ((info[1] >> 0) & 0xFFF) + 1; - int cacheAssociativity = ((info[1] >> 12) & 0x3FF) + 1; - int cachePartitions = ((info[1] >> 22) & 0x3FF) + 1; - int cacheSets = info[2] + 1; - int cacheSize = - cacheLinesPerTag * cacheAssociativity * cachePartitions * cacheSets; - - std::cout << "L" << cacheLevel << " "; - if (cacheType == 1) { - std::cout << "Data Cache: "; - } else if (cacheType == 2) { - std::cout << "Instruction Cache: "; - } else if (cacheType == 3) { - std::cout << "Unified Cache: "; - } - std::cout << cacheSize << " bytes" << std::endl; - } -} - PYBIND11_MODULE(_gc_mlir, m) { m.doc() = "Graph-compiler MLIR Python binding"; @@ -101,6 +57,4 @@ PYBIND11_MODULE(_gc_mlir, m) { } }, py::arg("context") = py::none(), py::arg("load") = true); - - cpuruntimeM.def("get_cpu_info", &get_cpu_info, "---"); } \ No newline at end of file diff --git a/test/benchgc/setup.py b/test/benchgc/setup.py index 1d3e67384..a9b49e6f0 100644 --- a/test/benchgc/setup.py +++ b/test/benchgc/setup.py @@ -26,5 +26,5 @@ packages=setuptools.find_packages("src") + setuptools.find_namespace_packages("../../python_packages/gc_mlir_core"), package_data={"gc_mlir": ["_mlir_libs/*.so"]}, - install_requires=["torch", "numpy", "py-cpuinfo"], + install_requires=["torch", "numpy"], ) diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index ed101f46b..dff0516d5 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -124,6 +124,20 @@ def add_common_options(parser: argparse.ArgumentParser): help="if we need print the ir during the pass-pipeline", ) + parser.add_argument( + "--cpu_cache_sizes", + required=False, + help="set the cpu cache sizes, format: L1:L2:L3", + type=str, + ) + + parser.add_argument( + "--max_vector_width", + required=False, + help="set the cpu max_vector_width", + type=int, + ) + if parser.parse_known_args()[0].driver == "linalg": parser.add_argument( "--cast", @@ -264,7 +278,7 @@ def get_module_and_args(flags): for arg in args: arg.print_verbose(flags.verbose) - benchgc.mlir.util.attch_dlti(module) + benchgc.mlir.util.attch_dlti(flags, module) if flags.verbose >= benchgc.util.MODULE_VERBOSE: print(module) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index c8d10cc93..1c3ecb882 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -18,10 +18,10 @@ import os from typing import Any, List -import cpuinfo import torch from gc_mlir import ir from gc_mlir.dialects import arith, func, memref +from gc_mlir.tools import cpuinfo # calling python binding consumes a lot of time e.g. get_name() @@ -156,28 +156,32 @@ def get_kernel_func_from_module( raise ValueError("can not find the entry function") -def attch_dlti(module: ir.Module): +def attch_dlti(flags, module: ir.Module): + # the moudle already had dlti attr if "dlti.target_system_spec" in module.operation.attributes: return - info = cpuinfo.get_cpu_info() - from gc_mlir.dialects import cpuruntime - cpuruntime.get_cpu_info() - print(info) - l1_data_cache_size = info.get("l1_data_cache_size") - l2_cache_size = info.get("l2_cache_size") - l3_cache_size = info.get("l3_cache_size") + if flags.cpu_cache_sizes: + caches_sizes = [int(x) for x in flags.cpu_cache_sizes.strip().split(":")] + else: + caches_sizes = cpuinfo.get_cache_sizes() + if not caches_sizes or len(caches_sizes) != 3: + print( + "Failed to get CPU cache sizes, please added them manually br --cpu_cache_sizes" + ) + return + if flags.max_vector_width: + max_vector_width = flags.max_vector_width + else: + max_vector_width = cpuinfo.get_max_vector_width() + if not max_vector_width: + print( + "Failed to get CPU max vector width, please added them manually br --max_vector_width" + ) + return + l1_data_cache_size, l2_cache_size, l3_cache_size = caches_sizes if "OMP_NUM_THREADS" not in os.environ: print("OMP_NUM_THREADS is not found, using 1 as default") num_threads = os.environ.get("OMP_NUM_THREADS", 1) - flags = info.get("flags") - max_vector_width = 64 - for flag in flags: - if "avx512f" == flag: - max_vector_width = max(512, max_vector_width) - elif "avx2" == flag or "avx" == flag: - max_vector_width = max(256, max_vector_width) - elif "sse" in flag: - max_vector_width = max(128, max_vector_width) dlti_template = f""" module attributes {{ @@ -195,4 +199,4 @@ def attch_dlti(module: ir.Module): template_module = ir.Module.parse(dlti_template) module.operation.attributes["dlti.target_system_spec"] = ( template_module.operation.attributes["dlti.target_system_spec"] - ) \ No newline at end of file + ) From 21860b4d4a822f2d0c66a0146d6e05e12587b664 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Thu, 5 Sep 2024 00:58:48 -0700 Subject: [PATCH 36/38] fix --- .github/workflows/build.yml | 3 +- python/CPUInfo.cpp | 66 +++++++++++++++++++++++++++ python/gc_mlir/tools/__init__.py | 7 +++ python/gc_mlir/tools/cpuinfo.py | 26 +++++++++++ test/benchgc/src/benchgc/mlir/util.py | 3 -- 5 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 python/CPUInfo.cpp create mode 100644 python/gc_mlir/tools/__init__.py create mode 100644 python/gc_mlir/tools/cpuinfo.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4b8c60d31..1159667bb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,7 +24,8 @@ jobs: - name: Set LLVM hash run: | - echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV + echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV + - name: Fetch requirements for python binding uses: actions/checkout@v4 with: diff --git a/python/CPUInfo.cpp b/python/CPUInfo.cpp new file mode 100644 index 000000000..c9acefa2a --- /dev/null +++ b/python/CPUInfo.cpp @@ -0,0 +1,66 @@ +#include "mlir/Bindings/Python/PybindAdaptors.h" + +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \ + defined(_M_IX86) +// x86 or x86_64 specific code +void cpuid(int info[4], int leaf, int subleaf) { + __asm__ __volatile__("cpuid" + : "=a"(info[0]), "=b"(info[1]), "=c"(info[2]), + "=d"(info[3]) + : "a"(leaf), "c"(subleaf)); +} + +std::vector getCacheSizes() { + int info[4]; + cpuid(info, 0, 0); + int nIds = info[0]; + int caches[3] = {}; + for (int i = 0; i <= nIds; ++i) { + cpuid(info, 4, i); + int cacheType = info[0] & 0x1F; + if (cacheType == 0) { + break; + } + + int cacheLevel = (info[0] >> 5) & 0x7; + int cacheLinesPerTag = ((info[1] >> 0) & 0xFFF) + 1; + int cacheAssociativity = ((info[1] >> 12) & 0x3FF) + 1; + int cachePartitions = ((info[1] >> 22) & 0x3FF) + 1; + int cacheSets = info[2] + 1; + int cacheSize = + cacheLinesPerTag * cacheAssociativity * cachePartitions * cacheSets; + if (cacheLevel >= 1 && cacheLevel <= 3) { + caches[cacheLevel - 1] = cacheSize; + } + } + return std::vector(std::begin(caches), std::end(caches)); +} + +bool isFeatureSupported(int function_id, int register_idx, int bit) { + int info[4]; + cpuid(info, function_id, 0); + return (info[register_idx] & (1 << bit)) != 0; +} + +int getMaxVectorWidth() { + if (isFeatureSupported(7, 1, 16)) { // Check for AVX-512F support + return 512; + } else if (isFeatureSupported(1, 2, 28)) { // Check for AVX support + return 256; + } else if (isFeatureSupported(1, 3, 25)) { // Check for SSE support + return 128; + } + return 64; // Default to 64 if none of the above features are supported +} +#else +std::vector getCacheSizes() { return {}; } + +int getMaxVectorWidth { return 0; } +#endif + +PYBIND11_MODULE(_cpuinfo, m) { + m.doc() = "Graph-compiler MLIR Python binding"; + m.def("get_cache_sizes", &getCacheSizes, "Get CPU L1,L2,L3 cache size"); + m.def("get_max_vector_width", &getMaxVectorWidth, + "Get CPU supported max vector width"); +} \ No newline at end of file diff --git a/python/gc_mlir/tools/__init__.py b/python/gc_mlir/tools/__init__.py new file mode 100644 index 000000000..172887970 --- /dev/null +++ b/python/gc_mlir/tools/__init__.py @@ -0,0 +1,7 @@ +# ===-- __init__.py - init ------------------------------------*- Python -*-===# +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# ===-----------------------------------------------------------------------===# diff --git a/python/gc_mlir/tools/cpuinfo.py b/python/gc_mlir/tools/cpuinfo.py new file mode 100644 index 000000000..7833ece68 --- /dev/null +++ b/python/gc_mlir/tools/cpuinfo.py @@ -0,0 +1,26 @@ +# ===-- cpuinfo.py - Getting the CPU info ---------------------*- Python -*-===# +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# ===-----------------------------------------------------------------------===# + +from .._mlir_libs import _cpuinfo + +_cache_sizes = [] +_max_vector_width = None + + +def get_cache_sizes(): + global _cache_sizes + if not _cache_sizes: + _cache_sizes = _cpuinfo.get_cache_sizes() + return _cache_sizes + + +def get_max_vector_width(): + global _max_vector_width + if _max_vector_width is None: + _max_vector_width = _cpuinfo.get_max_vector_width() + return _max_vector_width diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 46e6a2c4f..89d23202f 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -21,10 +21,7 @@ import torch from gc_mlir import ir from gc_mlir.dialects import arith, func, memref -<<<<<<< HEAD from gc_mlir.tools import cpuinfo -======= ->>>>>>> main # calling python binding consumes a lot of time e.g. get_name() From 6db244e1859b07840ddecd1664e7cc95937ee4a7 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Thu, 5 Sep 2024 01:01:50 -0700 Subject: [PATCH 37/38] fix --- .github/workflows/build.yml | 2 +- python/CPUInfo.cpp | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1159667bb..0ca233a76 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: - name: Set LLVM hash run: | echo LLVM_HASH=$(cat cmake/llvm-version.txt) >>$GITHUB_ENV - + - name: Fetch requirements for python binding uses: actions/checkout@v4 with: diff --git a/python/CPUInfo.cpp b/python/CPUInfo.cpp index c9acefa2a..c643f5a12 100644 --- a/python/CPUInfo.cpp +++ b/python/CPUInfo.cpp @@ -1,3 +1,21 @@ +/* + * Copyright (C) 2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions + * and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + #include "mlir/Bindings/Python/PybindAdaptors.h" #if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \ From 2ab9a48fe33abb3ec3bf68ca3ac024e0040cdb7d Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Sun, 8 Sep 2024 23:42:12 -0700 Subject: [PATCH 38/38] fix --- python/CPUInfo.cpp | 5 ++++- test/benchgc/README.md | 6 ++++++ test/benchgc/src/benchgc/__main__.py | 2 +- test/benchgc/src/benchgc/mlir/util.py | 3 ++- 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/CPUInfo.cpp b/python/CPUInfo.cpp index c643f5a12..98cecc9e1 100644 --- a/python/CPUInfo.cpp +++ b/python/CPUInfo.cpp @@ -39,7 +39,10 @@ std::vector getCacheSizes() { if (cacheType == 0) { break; } - + if (cacheType == 2) { + // skip instruction cache + continue; + } int cacheLevel = (info[0] >> 5) & 0x7; int cacheLinesPerTag = ((info[1] >> 0) & 0xFFF) + 1; int cacheAssociativity = ((info[1] >> 12) & 0x3FF) + 1; diff --git a/test/benchgc/README.md b/test/benchgc/README.md index c0444aa37..239105c82 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -109,6 +109,12 @@ module { | Pytorch tensor dump | F | dump filename | | Benchdnn driver | D | driver_name[:driver filling parameter]* | +### --cpu_cache_sizes, --max_vector_width +* BenchGC will automatically obtain target info and add the DLTI attr to the IR +* In some cases, if the system info obtained by BenchGC is not accurate, you can specify the relevant attributes for BenchGC through these options. +* --cpu_cache_sizes: cpu cache sizes in bytes, format: L1:L2:L3, example: `--cpu_cache_sizes 49152:2097152:110100480` +* --max_vector_width: the maximum width of vector registers available in a CPU, example `--max_vector_width ` + #### Benchdnn driver filling | driver_name | driver filling parameter | diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index dff0516d5..1f0d880aa 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -278,7 +278,7 @@ def get_module_and_args(flags): for arg in args: arg.print_verbose(flags.verbose) - benchgc.mlir.util.attch_dlti(flags, module) + benchgc.mlir.util.attach_dlti(flags, module) if flags.verbose >= benchgc.util.MODULE_VERBOSE: print(module) diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 89d23202f..db500fc26 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -14,6 +14,7 @@ # limitations under the License. ################################################################################ +import argparse import ctypes import os from typing import Any, List @@ -150,7 +151,7 @@ def get_kernel_func_from_module( raise ValueError("can not find the entry function") -def attch_dlti(flags, module: ir.Module): +def attach_dlti(flags: argparse.Namespace, module: ir.Module): # the moudle already had dlti attr if "dlti.target_system_spec" in module.operation.attributes: return