
XLA incorrectly optimizes scalar bool*float to select() #2438

Open
hawkinsp opened this issue Apr 10, 2023 · 9 comments

@hawkinsp
Member

google/jax#15492

Repro from JAX:

In [1]: print(jax.jit(lambda x, y: x*y).lower(f, inf).as_text(dialect="hlo"))
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}

ENTRY main.5 {
  Arg_0.1 = pred[] parameter(0), sharding={replicated}
  convert.3 = f32[] convert(Arg_0.1)
  Arg_1.2 = f32[] parameter(1), sharding={replicated}
  ROOT multiply.4 = f32[] multiply(convert.3, Arg_1.2)
}

In [2]: print(jax.jit(lambda x, y: x*y).lower(f, inf).compile().as_text())
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.5 (Arg_0.1: pred[], Arg_1.2: f32[]) -> f32[] {
  %Arg_0.1 = pred[] parameter(0), sharding={replicated}
  %Arg_1.2 = f32[] parameter(1), sharding={replicated}
  %constant.1 = f32[] constant(0)
  ROOT %select = f32[] select(pred[] %Arg_0.1, f32[] %Arg_1.2, f32[] %constant.1), metadata={op_name="jit(<lambda>)/jit(main)/mul" source_file="<ipython-input-3-2c0439f60e63>" source_line=1}
}

That is, during optimization XLA has rewritten the multiplication by a (converted) scalar bool into a select. This is incorrect because it does not preserve inf/nan semantics: 0 * inf should be nan, not 0.
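For concreteness, here is a minimal sketch of the semantic gap (assuming the repro's f and inf are a scalar bool and float32 infinity, consistent with the (pred[], f32[]) entry layout above):

import numpy as np
import jax
import jax.numpy as jnp

b = jnp.array(False)                      # becomes the pred[] parameter
x = jnp.array(np.inf, dtype=jnp.float32)  # becomes the f32[] parameter

# Per IEEE 754, 0.0 * inf is NaN:
print(np.float32(0.0) * np.float32(np.inf))   # nan (NumPy reference behavior)

# The compiled XLA version reportedly returns 0.0 instead, because the
# multiply has been rewritten to select(pred, x, 0.0):
print(jax.jit(lambda b, x: b * x)(b, x))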

@jakevdp
Contributor

jakevdp commented Apr 10, 2023

Internal tracking bug for this issue: b/245348010, related to an earlier report: google/jax#12233

@reedwm
Member

reedwm commented Apr 11, 2023

Fixing this has a significant negative performance impact on TPUs. Additionally, according to the internal bug, many users rely on the fact that 0 * inf results in 0, which is unfortunate and makes fixing this harder.

Because of this, I don't think this is worth fixing.

@soraros

soraros commented Apr 13, 2023

@reedwm Those are fair concerns, but I'd argue that this is an "unsafe optimisation" at best (a miscompilation at worst), and it would be better put behind a flag if straight-up fixing it is not possible.

@reedwm
Member

reedwm commented Apr 13, 2023

XLA in general does optimizations assuming floating-point values will not overflow or be NaN. These are "unsafe optimizations" in the sense that XLA's behavior matches gcc's -funsafe-math-optimizations flag.

We could add a no-unsafe-math-optimizations flag to XLA, but it would take a lot of effort to implement and would make things very slow, at least on GPU. We could instead add a flag to disable this specific scenario (optimizing bool*float to select). I don't think it's worth it unless there is a very compelling reason why this particular optimization should be disabled in some cases. /CC @cheshire

@soraros do you know of a particular case where this optimization causes issues in practice, or is this more of a general concern about unsafe optimizations?

@soraros

soraros commented Apr 14, 2023

@reedwm My concerns are more on the general side, namely: how strong a semantics-preservation guarantee does XLA provide for a given piece of, say, StableHLO?

XLA in general does optimizations assuming floating-point values will not overflow or be NaN.

This is of course expected, but how does it apply in the case of bool * float, which is not even directly representable in HLO? And I wonder about the significance of this particular optimisation: it hits differently than other "unsafe" algebraic simplifications, like fused a * b + c not being entirely IEEE-conformant.

Another question might be, for a frontend like JAX, what's the suggested way to generate IR that defeats this optimisation? Heterogeneous multiplication in JAX is currently implemented roughly as

def mul(a, b):
  dtype = dtypes.join(a.dtype, b.dtype)
  if a.dtype != dtype:
    a = stablehlo.convert(a, dtype)
  if b.dtype != dtype:
    b = stablehlo.convert(b, dtype)
  return stablehlo.multiply(a, b)

which generates the following for i1 x f32:

module @mul {
  func.func public @main(%arg0: tensor<i1>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1 : tensor<f32>
  }
}

I can think of one way that works with the current version of XLA, but it is (I hope you would agree with me) quite ugly:

module @mul_bool_float {
  func.func public @main(%arg0: tensor<3xi1>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<3xf32>
    %2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<3xf32>
    %4 = stablehlo.select %arg0, %1, %3 : tensor<3xi1>, tensor<3xf32>
    %5 = stablehlo.multiply %4, %arg1 : tensor<3xf32>
    return %5 : tensor<3xf32>
  }
}
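At the JAX level, a rough sketch of the same workaround (the helper name is mine, not from the thread) could look like this:

import jax.numpy as jnp

def mul_via_explicit_select(b, f):
  # Materialize the bool as an explicit 1.0/0.0 select so the multiply no
  # longer looks like multiply(convert(pred), f) to the algebraic simplifier.
  return jnp.where(b, jnp.float32(1.0), jnp.float32(0.0)) * f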

@reedwm
Member

reedwm commented Apr 14, 2023

how does it apply in the case of bool * float which is not even representable directly in HLO?

As you mentioned, the bool is converted to float in frontends like JAX. Saying "bool * float" is shorthand for "taking a bool tensor, casting it to float32, and multiplying it with another float32 tensor".

I agree that this optimization "hits differently" than some other unsafe optimizations, because an input NaN is not propagated to the output. I don't know of any other optimizations that drop NaNs (they might exist, but I couldn't find any with a quick search), but there are optimizations that introduce NaN/Inf where they shouldn't. Because fixing this would slow down TPUs, and because it's unlikely to negatively affect any models, I don't think it's worth fixing. @blakehechtman, do you have any opinions on whether this optimization should be done?

To create IR that defeats this optimization, your mul_bool_float function would work, but it is slightly risky, as XLA in the future might recognize that %4 is equivalent to casting %arg0 to float. I would defeat it as follows (using JAX instead of StableHLO for conciseness):

import jax.numpy as jnp

def multiply_bool_and_float(b, f):
  result = b * f
  # Force NaN whenever f is non-finite, so the select rewrite cannot drop it.
  return jnp.where(jnp.isfinite(f), result, jnp.nan)
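
As a quick sanity check of this workaround (my example, assuming scalar JAX array inputs):

print(multiply_bool_and_float(jnp.array(False), jnp.array(jnp.inf)))  # expected: nan
print(multiply_bool_and_float(jnp.array(True), jnp.array(2.0)))       # expected: 2.0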

@burmako it might be worth documenting in the StableHLO spec that the precise numeric behavior of all floating-point operations is implementation-dependent, and that optimizations may introduce or drop Infs/NaNs in places where they would not occur without optimizations.

@soraros

soraros commented Apr 16, 2023

@reedwm Sorry if I sound unreasonably paranoid or stubborn, and thank you for your replies.

Saying bool * float is shorthand for "taking a bool tensor, casting to float32, and multiplying with another float32 tensor".

I think that's exactly why it's troublesome. Suppose I have a function f that returns a float a, and a function g that consumes a, doing some computation on it that includes an a * b for some float b. The implementation of f might use boolean arithmetic internally and convert the result to float on return. And per your reply

... but is slightly risky as XLA in the future might recognise %4 is equivalent to casting %arg0 to float.

there is no good way to program defensively on the producer side (f in this case) to defeat this optimisation. I'm sure inserting OptimizationBarrierOp doesn't count, right? And even with full control of the consumer (g), looking at your multiply_bool_and_float, it still feels obscure and arguably hackish.
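For reference, the barrier approach mentioned above might look roughly like the sketch below (assuming jax.lax.optimization_barrier is available, and with no guarantee that the simplifier actually treats the barrier as opaque here):

import jax

def mul_with_barrier(b, f):
  # Hide convert(pred) behind an optimization barrier so the simplifier
  # cannot see a multiply(convert(pred), f) pattern to rewrite.
  converted = jax.lax.optimization_barrier(b.astype(f.dtype))
  return converted * f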

Because fixing this would slow down TPUs, ...

Could we at least disable it on other platforms?

@sdasgup3
Member

sdasgup3 commented Apr 19, 2023

it might be worth documenting in the StableHLO spec that the precise numeric behavior of all floating-point operations is implementation-dependent, and that optimizations may introduce or drop Infs/NaNs in places where they would not occur without optimizations.

@reedwm

I've had some more time to think about the proposed changes to the StableHLO spec, and I'm afraid I'm not convinced that allowing such aggressive optimizations is the best approach. One concern is that this could equally enable conformant implementations other than XLA to mis-compile. Another concern (which may apply to documenting any aggressive/unsafe optimization in the spec): if such an optimization is later dropped or put behind a special flag, both of which are viable options, how should that be reflected in the spec? Drop it in the absence of a use case, or let other implementations interpret it in their own way?

Anyway, I'm open to discussing this further, and I'm happy to help explore other options.

\cc @burmako

@reedwm
Member

reedwm commented Apr 19, 2023

@soraros I agree multiply_bool_and_float is hackish, but we don't know of any use cases where defeating this optimization is necessary. If a user really wants to defeat it, they have a hacky solution; I don't think it's worth considering better solutions unless we know of a use case that requires one.

As for disabling this optimization only on non-TPU platforms: I don't like this because it makes behavior among different devices inconsistent. Since some users (incorrectly) rely on this optimization, if the optimization was only done on the TPU, it would be difficult for such users to migrate to other devices.

@sdasgup3 this wouldn't be considered a mis-compilation if the spec documents that it is OK. As for having a special flag to enable/disable this optimization: the spec doesn't need to mention any flags at all. It just needs to state that operations which would normally return NaN under IEEE 754 are free to return any other value instead, and that which value is returned is implementation-defined.

Anyway, it's not up to me whether we disable this optimization on XLA-TPU. I don't think XLA-TPU developers can be persuaded to take a performance hit by disabling it, even if the StableHLO spec ends up stating the optimization is disallowed, but @blakehechtman, please give your thoughts here. Since I want XLA-GPU to be consistent with XLA-TPU, I'm unwilling to make this change in XLA-GPU as well. I don't think we should add any flags to disable this behavior unless there are non-hypothetical models that require the optimization to be disabled.

@penpornk added the bug label on Feb 29, 2024