
Unexpected behavior when multiplying Jax boolean with jnp.inf #15492

Closed
aguerra7002 opened this issue Apr 8, 2023 · 4 comments
Labels: bug, XLA

Comments

Description

Hello! I was recently struggling to track down a bug in my code when I realized the problem came from some unexpected behavior in JAX. Below, note the inconsistency in evaluation when multiplying jnp.array(False) with arrays of length > 1:

import jax.numpy as jnp

f = jnp.array(False)
t = jnp.array(True)

inf = jnp.inf
a = jnp.array([inf])
b = jnp.array([inf, inf])

# Scalar multiplication (what I used as a baseline for 'normal')
print(t * inf)  # -> Array(inf, dtype=float32)
print(f * inf)  # -> Array(0., dtype=float32)

# Array multiplication
print(t * a)  # normal behavior -> Array([inf], dtype=float32)
print(f * a)  # normal behavior -> Array([0.], dtype=float32)
print(t * b)  # normal behavior -> Array([inf, inf], dtype=float32)
print(f * b)  # unexpected behavior! -> Array([nan, nan], dtype=float32)

It seems unlikely that this behavior is intentional, so I decided to open this issue. Thanks in advance for any help!

What jax/jaxlib version are you using?

jax 0.4.8, jaxlib 0.4.7

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.9.12, using Ubuntu on WSL2

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.65       Driver Version: 527.37       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   53C    P8     2W /  N/A |   3850MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11014      C   /python3.9                      N/A      |
+-----------------------------------------------------------------------------+
aguerra7002 added the bug label on Apr 8, 2023
soraros commented Apr 9, 2023

I actually find it super weird that jnp.array(False) * jnp.inf gives 0.
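
For reference (not from the thread): under IEEE-754, 0.0 * inf is nan, so plain NumPy returns nan here rather than 0, for scalars and arrays alike. A minimal comparison, assuming any recent NumPy:

import numpy as np

# NumPy promotes the boolean to 0.0 and follows IEEE-754, where 0.0 * inf
# is nan, regardless of whether the other operand is a scalar or an array.
print(np.array(False) * np.inf)                      # -> nan
print(np.array(False) * np.array([np.inf, np.inf]))  # -> [nan nan]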

hawkinsp added the XLA label on Apr 10, 2023
hawkinsp (Member) commented:
This looks like an incorrect optimization in XLA:

In [1]: print(jax.jit(lambda x, y: x*y).lower(f, inf).as_text(dialect="hlo"))
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}

ENTRY main.5 {
  Arg_0.1 = pred[] parameter(0), sharding={replicated}
  convert.3 = f32[] convert(Arg_0.1)
  Arg_1.2 = f32[] parameter(1), sharding={replicated}
  ROOT multiply.4 = f32[] multiply(convert.3, Arg_1.2)
}

In [2]: print(jax.jit(lambda x, y: x*y).lower(f, inf).compile().as_text())
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.5 (Arg_0.1: pred[], Arg_1.2: f32[]) -> f32[] {
  %Arg_0.1 = pred[] parameter(0), sharding={replicated}
  %Arg_1.2 = f32[] parameter(1), sharding={replicated}
  %constant.1 = f32[] constant(0)
  ROOT %select = f32[] select(pred[] %Arg_0.1, f32[] %Arg_1.2, f32[] %constant.1), metadata={op_name="jit(<lambda>)/jit(main)/mul" source_file="<ipython-input-3-2c0439f60e63>" source_line=1}
}
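
A possible workaround sketch (not from the thread, and it assumes the intended semantics are masking rather than strict IEEE multiplication): expressing "zero out wherever the mask is False" with jnp.where selects elements instead of multiplying, so it does not depend on how XLA rewrites the bool-to-float multiply and returns 0 consistently for scalar and array operands:

import jax.numpy as jnp

# Hypothetical masking workaround: jnp.where picks a branch per element, so
# the inf values in the masked-out case never reach a 0 * inf multiplication.
f = jnp.array(False)
b = jnp.array([jnp.inf, jnp.inf])

print(jnp.where(f, b, 0.0))  # -> Array([0., 0.], dtype=float32)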

jakevdp (Collaborator) commented Apr 10, 2023

Possible duplicate: #12233

hawkinsp (Member) commented:
Closing as duplicate of #12233
