In [45]:
import jax
import jax.numpy as jnp

In [None]:
@jax.jit
def main():
    """Deliberately broken: filter an array with ``jnp.where`` inside ``jax.jit``.

    Kept only to demonstrate the failure. With a single argument,
    ``jnp.where(cond)`` is ``jnp.nonzero(cond)``, whose output length depends
    on the *values* of ``puzzles``. Those values are abstract while
    ``jax.jit`` traces the function, so tracing fails with
    ``ConcretizationTypeError``.
    """
    key = jax.random.PRNGKey(1)

    puzzles = jnp.array([0, 1, 2, 3, 4, 5])
    # Data-dependent output shape (no static `size=`), which is not allowed under jit.
    # Note: this also yields a tuple of *index* arrays, not the matching values.
    filtered_puzzles = jnp.where(puzzles > 3)
    choice = jax.random.choice(key, filtered_puzzles, shape=())
    return choice


# The failure is the point of this cell: catch and show it so that
# "Restart & Run All" continues to the working version further down.
try:
    main()
except jax.errors.ConcretizationTypeError as err:
    print(f"{type(err).__name__}: {err}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function main at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_37873/3900357575.py:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_37873/3900357575.py:5:14 (main)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_37873/3900357575.py:6:33 (main)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

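The code above raises a `ConcretizationTypeError` because the length of `filtered_puzzles` depends on the *values* stored in `puzzles`. Under `jax.jit`, JAX traces the function with abstract values and must know the shape of every array at trace time, so an array whose size depends on runtime data is not allowed. (With a single argument, `jnp.where(cond)` is equivalent to `jnp.nonzero(cond)`, which is why the error message mentions `jnp.nonzero` and its `size` argument. It also returns a tuple of *index* arrays rather than the matching values.)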


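The fix keeps the array at its full, static size and passes the boolean mask as the sampling weights `p` of `jax.random.choice`. Elements that fail the condition get weight zero and are never drawn: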


In [44]:
@jax.jit
def main(key=None, puzzles=None, threshold=3, fill_value=-1):
    """Sample one element of ``puzzles`` greater than ``threshold``, jit-safely.

    Rather than filtering, which would produce a value-dependent shape, the
    boolean mask is passed to ``jax.random.choice`` as unnormalised 0/1
    weights ``p``. Every array keeps its static shape, and masked-out
    elements (weight 0) are never drawn.

    Args:
        key: PRNG key; defaults to ``jax.random.PRNGKey(1)``.
        puzzles: 1-D array of candidates; defaults to ``[0, 1, 2, 3, 4, 5]``.
        threshold: only elements strictly greater than this can be chosen.
        fill_value: returned when no element exceeds ``threshold``. A jitted
            function cannot raise on a traced condition, so a sentinel is used.

    Returns:
        A scalar array: the sampled element, or ``fill_value`` if none qualify.
    """
    if key is None:
        key = jax.random.PRNGKey(1)
    if puzzles is None:
        puzzles = jnp.array([0, 1, 2, 3, 4, 5])

    is_valid = puzzles > threshold
    # The boolean mask is used directly as the (unnormalised) weights `p`.
    choice = jax.random.choice(key, puzzles, shape=(), p=is_valid)
    # With all-zero weights the draw above would silently return an element
    # that fails the condition, so substitute the sentinel in that case.
    return jnp.where(is_valid.any(), choice, fill_value)


main()


Array(5, dtype=int32)