[BUG] mhlo.uniform_quantize/mhlo.uniform_dequantize op can't be translated to XLA HLO
#9291
Comments
@GleasonK for more context, but overall quantize/dequantize ops are not supported in the OpenXLA compiler.
This is expected for XLA compilation: there is no compiler support for quantized ops/types, but all the underlying math can be represented in HLO. What we need is a pass which decomposes these ops into the necessary int8 computations. Such a pass exists in the TF repo (convert_mhlo_quant_to_int.cc), but I'm not sure of its current status (limited/experimental support, production ready, very tied to a specific impl, etc.). I'll follow up with the pass authors to see if this is something we can upstream to openxla/xla to allow for lowering programs with quantized types to HLO.
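To illustrate what "the underlying math" means here, below is a rough sketch of the plain arithmetic that a quantize/dequantize pair reduces to. It assumes per-tensor quantization, int8 storage, and round-to-nearest-even; the function names and parameters are made up for illustration and are not taken from convert_mhlo_quant_to_int.cc, whose actual decomposition may differ.

```python
# Hypothetical helpers, for illustration only (not the TF pass's API).
# Assumes per-tensor scale/zero_point and int8 storage bounds.

def uniform_quantize(values, scale, zero_point, storage_min=-128, storage_max=127):
    """Map floats to integers: round(x / scale) + zero_point, then clamp."""
    out = []
    for x in values:
        # Python's round() rounds halves to the nearest even integer.
        q = round(x / scale) + zero_point
        out.append(max(storage_min, min(storage_max, q)))
    return out


def uniform_dequantize(quantized, scale, zero_point):
    """Map integers back to floats: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in quantized]


if __name__ == "__main__":
    q = uniform_quantize([0.0, 0.5, 1.0, 100.0], scale=0.5, zero_point=0)
    print(q)
    print(uniform_dequantize(q, scale=0.5, zero_point=0))
```

Every step is ordinary divide, round, add, clamp, subtract, and multiply, which is why a decomposition pass can express it with existing HLO ops instead of quantized types.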
Also, I should confirm: is your goal just to compile/execute that program, or are you trying to roundtrip this program between HLO and StableHLO? My idea above is fairly uni-directional and will help with the "compile and execute" workflow; roundtripping is a slightly different requirement.
The goal is to compile and execute that logic from the program.
In that case the above pass should work! I've reached out to the maintainers to ask about upstreaming to openxla/xla. I should have word back soon. If that pass turns out to be impl-specific, we can add our own decompositions. Once we have decompositions, we'll be able to improve the compilation support for programs with quantized types.
Compiling a program through a PJRT client fails when it calls either stablehlo.uniform_quantize or stablehlo.uniform_dequantize; e.g. the following snippet should compile:
Bug reproduction here: https://github.com/cryptodeal/xla
To run, clone the fork:
Running the bug reproduction yields the following output: