Hello! I was recently struggling to find a bug in my code when I realized the problem came from some unexpected behavior in JAX. Below, note the inconsistency in evaluation when multiplying jnp.array(False) with arrays of length > 1:
import jax.numpy as jnp
f = jnp.array(False)
t = jnp.array(True)
inf = jnp.inf
a = jnp.array([inf])
b = jnp.array([inf, inf])
# Scalar multiplication (what I used as a baseline for 'normal')
print( t * inf ) # -> Array(inf, dtype=float32)
print( f * inf ) # -> Array(0., dtype=float32)
# Array multiplication
print( t * a ) # normal behavior -> Array([inf], dtype=float32)
print( f * a ) # normal behavior -> Array([0.], dtype=float32)
print( t * b ) # normal behavior -> Array([inf, inf], dtype=float32)
print( f * b ) # unexpected behavior! -> Array([nan, nan], dtype=float32)
This behavior seems unlikely to be intentional, so I decided to open this issue. Thanks in advance for any help!
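In case it helps anyone hitting the same thing: a workaround that avoids multiplying a boolean by a float array entirely is to select with jnp.where instead. This is just a sketch of the approach I'm using in the meantime, not a claim about the root cause:

import jax.numpy as jnp

f = jnp.array(False)
b = jnp.array([jnp.inf, jnp.inf])

# jnp.where selects elementwise rather than computing bool * float,
# so the inf entries never participate in an arithmetic op.
print(jnp.where(f, b, 0.0))  # -> Array([0., 0.], dtype=float32)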
What jax/jaxlib version are you using?
jax 0.4.8, jaxlib 0.4.7
Which accelerator(s) are you using?
GPU
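(I have only tested on GPU so far. For what it's worth, here is a sketch of how I would check whether the nan also shows up on the CPU backend, using jax.default_device; I haven't verified whether the result differs by backend:)

import jax
import jax.numpy as jnp

# Re-run the repro pinned to the CPU backend to see whether the
# nan result is GPU-specific.
with jax.default_device(jax.devices('cpu')[0]):
    f = jnp.array(False)
    b = jnp.array([jnp.inf, jnp.inf])
    print(f * b)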
Additional system info
Python 3.9.12, using Ubuntu on WSL2
NVIDIA GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.65       Driver Version: 527.37       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   53C    P8     2W /  N/A |   3850MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11014      C   /python3.9                        N/A    |
+-----------------------------------------------------------------------------+