
[BUG] mhlo.uniform_quantize/mhlo.uniform_dequantize op can't be translated to XLA HLO #9291

Open
cryptodeal opened this issue Feb 7, 2024 · 5 comments

@cryptodeal

Compiling a PJRT client program that calls either stablehlo.uniform_quantize or stablehlo.uniform_dequantize fails; e.g. the following snippet should compile:

func.func @main(%arg: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.uniform_quantize %arg {operandSegmentSizes = array<i32: 1>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 3.400000e+01:16>>
    %1 = stablehlo.uniform_dequantize %0 : (tensor<4x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
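For context, the type `!quant.uniform<i8:f32, 3.400000e+01:16>` in the snippet denotes an i8 storage type, an f32 expressed type, a scale of 34.0, and a zero point of 16. The quantize/dequantize round trip under those parameters can be sketched in Python (a minimal sketch of standard affine quantization semantics, not XLA's implementation):

```python
def quantize(x: float, scale: float = 34.0, zero_point: int = 16) -> int:
    # affine quantization: q = clamp(round(x / scale) + zero_point, i8 range)
    q = round(x / scale) + zero_point
    return max(-128, min(127, q))

def dequantize(q: int, scale: float = 34.0, zero_point: int = 16) -> float:
    # inverse mapping: x ~= (q - zero_point) * scale
    return (q - zero_point) * scale
```

Note the round trip is lossy: `quantize(50.0)` yields 17, and `dequantize(17)` gives back 34.0, not 50.0.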

Bug reproduction here: https://github.com/cryptodeal/xla

To run, clone the fork:

cd xla/examples/quant
bazelisk test :stablehlo_compile_test --test_output=all --nocheck_visibility

Running the bug reproduction yields the following output:

==================== Test output for //xla/examples/quant:stablehlo_compile_test:
Running main() from gmock_main.cc
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from StableHloQuantTest
[ RUN      ] StableHloQuantTest.LoadAndRunCpuExecutable
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707335314.922530 17625729 service.cc:145] XLA service 0x6000007b4700 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1707335314.922607 17625729 service.cc:153]   StreamExecutor device (0): Host, Default Version
Loaded StableHLO program from /private/var/tmp/_bazel_cryptodeal/79529a3a6255d7cff0ce44cfe0b03ab6/execroot/xla/bazel-out/darwin_arm64-opt/bin/xla/examples/quant/stablehlo_compile_test.runfiles/xla/xla/examples/quant/stablehlo_quant.mlir:
func.func @main(%arg: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.uniform_quantize %arg {operandSegmentSizes = array<i32: 1>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 3.400000e+01:16>>
    %1 = stablehlo.uniform_dequantize %0 : (tensor<4x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}

xla/examples/quant/stablehlo_compile_test.cc:107: Failure
Value of: _status_or_value14.status().ok()
  Actual: false
Expected: true
UNKNOWN: -:2:10: error: 'mhlo.uniform_quantize' op can't be translated to XLA HLO
-:2:10: note: see current operation: %0 = "mhlo.uniform_quantize"(%arg0) {operandSegmentSizes = array<i32: 1>} : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 3.400000e+01:16>>

[  FAILED  ] StableHloQuantTest.LoadAndRunCpuExecutable (7 ms)
[----------] 1 test from StableHloQuantTest (7 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (8 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] StableHloQuantTest.LoadAndRunCpuExecutable

 1 FAILED TEST
================================================================================
@cheshire (Member) commented Feb 8, 2024:

@GleasonK for more context, but overall quantize/dequantize ops are not supported in the OpenXLA compiler.

@GleasonK (Member) commented Feb 8, 2024:

This is expected for XLA compilation - there is no compiler support for quantized ops/types, but all the underlying math can be represented in HLO. What we need is a pass which decomposes these ops into the necessary int8 computations. Such a pass exists in the TF repo: convert_mhlo_quant_to_int.cc but I'm not sure of the current status (limited/experimental support, production ready, very tied to a specific impl, etc). I'll follow up with the pass authors to see if this is something we can upstream to openxla/xla to allow for lowering programs with quantized types to HLO.
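To illustrate what such a decomposition does (a hedged sketch assuming standard affine quantization; the function names here are illustrative, not taken from convert_mhlo_quant_to_int.cc), the quantized ops reduce to plain elementwise math that HLO already supports, such as divide, round, add, clamp, and convert:

```python
import numpy as np

def decomposed_uniform_quantize(x: np.ndarray, scale: float,
                                zero_point: int) -> np.ndarray:
    # Each step maps onto an existing HLO op: divide, round_nearest_even,
    # add, clamp, convert -- no quantized types needed.
    q = np.round(x / scale) + zero_point
    q = np.clip(q, -128, 127)  # clamp to the i8 storage range
    return q.astype(np.int8)

def decomposed_uniform_dequantize(q: np.ndarray, scale: float,
                                  zero_point: int) -> np.ndarray:
    # widen, subtract the zero point, convert to f32, multiply by the scale
    return (q.astype(np.int32) - zero_point).astype(np.float32) * scale
```

With the scale 34.0 and zero point 16 from the reported program, `decomposed_uniform_quantize(np.array([0.0, 34.0, 68.0], dtype=np.float32), 34.0, 16)` produces the int8 values `[16, 17, 18]`.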

@GleasonK (Member) commented Feb 8, 2024:

Also I should confirm: Is your goal just to compile / execute that program? Or are you trying to roundtrip this program between HLO/StableHLO? My above idea is fairly uni-directional and will help with the "compile and execute" workflow; roundtripping is a slightly different requirement.

@cryptodeal (Author) commented:

> Also I should confirm: Is your goal just to compile / execute that program? Or are you trying to roundtrip this program between HLO/StableHLO? My above idea is fairly uni-directional and will help with the "compile and execute" workflow, roundtripping is a slightly different requirement

The goal is to compile and execute that logic from the program.

@GleasonK (Member) commented Feb 8, 2024:

In that case the above pass should work! I've reached out to the maintainers to ask about upstreaming to openxla/xla and should have word back soon. If that pass turns out to be impl-specific, we can add our own decompositions. Once we have decompositions, we'll be able to improve the compilation support for programs with quantized types.
