XLA incorrectly optimizes scalar bool*float to select() #2438
Internal tracking bug for this issue: b/245348010, related to an earlier report: google/jax#12233
Fixing this has a significant negative performance impact on TPUs. Additionally, according to the internal bug, many users rely on this behavior. Because of this, I don't think this is worth fixing.
@reedwm Those are fair concerns, but I'd argue that this is an "unsafe optimisation" at best (and a miscompilation at worst), and would be better put behind a flag if straight-up fixing it is not possible.
XLA in general does optimizations assuming floating-point values will not overflow or be NaN. These are "unsafe optimisations" in the sense that XLA's behavior is similar to gcc's behavior under fast-math flags. We could add a flag to disable this particular optimization. @soraros, do you know of a particular case where this optimization causes issues in practice, or is this more of a general concern about unsafe optimizations?
@reedwm My concerns are more on the general side: how much semantic-preservation guarantee does XLA provide, given a piece of, say, StableHLO?

This is of course expected, but how does it apply in this case? Another question might be: for a frontend like JAX, what is the suggested way to generate IR that defeats this optimisation? Heterogeneous multiplication in JAX is currently implemented roughly as

```python
def mul(a, b):
    dtype = dtypes.join(a.dtype, b.dtype)
    if a.dtype != dtype:
        a = stablehlo.convert(a, dtype)
    if b.dtype != dtype:
        b = stablehlo.convert(b, dtype)
    return stablehlo.multiply(a, b)
```

which generates the following for a bool and a float input:

```mlir
module @mul {
func.func public @main(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
%1 = stablehlo.multiply %0, %arg1 : tensor<f32>
return %1 : tensor<f32>
}
}
```

I can think of one way which works with the current version of XLA that is (I hope you would agree with me) quite ugly:

```mlir
module @mul_bool_float {
func.func public @main(%arg0: tensor<3xi1>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<3xf32>
%2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<3xf32>
%4 = stablehlo.select %arg0, %1, %3 : tensor<3xi1>, tensor<3xf32>
%5 = stablehlo.multiply %4, %arg1 : tensor<3xf32>
return %5 : tensor<3xf32>
}
}
```
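For what it's worth, here is a JAX-level sketch that should lower to roughly this select-then-multiply pattern; the function name and inputs are illustrative, and whether this reliably defeats the rewrite is an assumption, not something verified in this thread.

```python
import jax
import jax.numpy as jnp

def mul_via_select(mask, x):
    # Replace the bool->float convert with an explicit select between
    # 1.0 and 0.0, then do an ordinary float multiply.
    return jnp.where(mask, 1.0, 0.0) * x

mask = jnp.array([True, False, True])
x = jnp.array([1.0, jnp.inf, jnp.nan])

# Inspect the StableHLO emitted for this function.
print(jax.jit(mul_via_select).lower(mask, x).as_text())
```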
As you mentioned, the bool is converted to float in frontends like JAX. I agree that this optimization "hits differently" than some other unsafe optimizations, because the input has a NaN which is not propagated to the output. I don't know of any other optimizations that drop NaNs (they might exist, but I couldn't find anything with a quick search), but there are optimizations that introduce NaN/Inf where they shouldn't.

Because fixing this would slow down TPUs, and because it's unlikely to negatively affect any models, I don't think it's worth fixing. @blakehechtman, do you have any opinions on whether this optimization should be done?

To create IR that defeats this optimization, you could write something like:

```python
import jax.numpy as jnp

def multiply_bool_and_float(b, f):
    result = b * f
    # Wherever f is inf or nan, return nan explicitly instead of relying
    # on the multiply to propagate it.
    return jnp.where(jnp.isfinite(f), result, jnp.nan)
```

@burmako it might be worth documenting in the StableHLO spec that the precise numeric behavior of all floating-point operations is implementation-dependent, and that optimizations may introduce or drop Infs/NaNs in places where they would not occur without optimizations.
@reedwm Sorry if I sound unreasonably paranoid or stubborn, and thank you for your replies.
I think that's exactly why this is troublesome. Suppose I have a function whose correctness depends on IEEE inf/nan semantics: there is no good way to program defensively against such optimizations on the producer side (e.g. in a frontend like JAX).

Could we at least disable them on other platforms?
I've had some more time to think about the proposed changes to the StableHLO spec, and I'm afraid I'm not convinced that allowing such aggressive optimizations is the best approach. One of my concerns is that this could equally enable conformant implementations other than XLA to perform miscompilations. Another concern, which may apply to documenting any aggressive/unsafe optimization in the spec: in the event that such an optimization is later dropped or moved behind a special flag (both viable options), how should that be reflected in the spec: drop it in the absence of a use-case, or let other implementations interpret it their own way? Anyway, I'm open to discussing this further, and I'm happy to help explore other options. cc @burmako
@soraros I agree with your concern. As for disabling this optimization only on non-TPU platforms: I don't like this because it makes behavior inconsistent across devices. Since some users (incorrectly) rely on this optimization, if the optimization were only done on the TPU, it would be difficult for such users to migrate to other devices.

@sdasgup3 this wouldn't be considered a miscompilation if the spec documents that it is OK. As for having a special flag to enable/disable this optimization: the spec doesn't need to mention any flags at all. It just needs to state that operations which would normally return NaN per the IEEE-754 spec are free to return any other value instead, and that which value is returned is implementation-defined.

Anyway, it's not up to me whether we disable this optimization on XLA-TPU. I don't think XLA-TPU developers can be persuaded to take a performance hit by disabling this optimization, even if the StableHLO spec ends up stating the optimization is disallowed, but @blakehechtman please give your thoughts here. Since I want XLA-GPU to be consistent with XLA-TPU, I'm unwilling to make this change in XLA-GPU as well. I don't think we should add any flags to disable this behavior unless there are non-hypothetical models which require the optimization to be disabled.
google/jax#15492
Repro from JAX:
i.e., during optimization, XLA has changed a multiplication by a scalar bool to a select. This is incorrect because it does not have correct inf/nan semantics.
0 * inf should be nan, not 0.
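The original repro code is not preserved in the text above; below is a minimal sketch in the same spirit, with a function name and inputs that are illustrative rather than taken from google/jax#15492.

```python
import jax
import jax.numpy as jnp

def f(b, x):
    # Multiply a scalar bool by a float; per IEEE-754, 0.0 * inf is nan.
    return b * x

# Per this issue, the compiled result is 0.0 rather than nan, because XLA
# rewrites the bool*float multiply into a select.
print(jax.jit(f)(False, jnp.inf))
```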