```python
import jax
import jax.numpy as jnp
```

In JAX, we know that code like this works perfectly fine:


```python
def fn():
    count = 0
    for i in [0, 1, 2, 3, 4]:
        count += int(i)
    return count


jax.jit(fn)()
```

```
Array(10, dtype=int32, weak_type=True)
```

While JAX traces `fn`, the `for` loop over the Python list `[0, 1, 2, 3, 4]` runs as ordinary Python code, so the loop is effectively "unrolled" and the traced function is equivalent to this:


```python
def fn():
    count = 0
    count += int(0)
    count += int(1)
    count += int(2)
    count += int(3)
    count += int(4)
    return count


jax.jit(fn)()
```

```
Array(10, dtype=int32, weak_type=True)
```
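One way to check what tracing actually records is to inspect the jaxpr with `jax.make_jaxpr`. A minimal sketch, assuming a recent JAX release: because the loop over the Python list runs entirely in Python at trace time, the jaxpr should contain no operations at all, only the constant result.

```python
import jax


def fn():
    count = 0
    for i in [0, 1, 2, 3, 4]:
        count += int(i)
    return count


# make_jaxpr traces fn and returns a ClosedJaxpr describing the recorded operations
closed = jax.make_jaxpr(fn)()
print(closed)

# No JAX operations were recorded: the Python loop already ran during tracing
num_ops = len(closed.jaxpr.eqns)
print(num_ops)
```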

What if we iterate over a JAX array instead of a normal Python list?


```python
def fn():
    count = 0
    for i in jnp.array([0, 1, 2, 3, 4]):
        count += int(i)
    return count


jax.jit(fn)()
```

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function fn at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_33732/1016498956.py:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[5] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_33732/1016498956.py:3:13 (fn)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

We now get a `ConcretizationTypeError`. For some reason, `i` is now a Tracer, so the Python built-in `int(i)`, which needs a concrete value, fails.
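If all we need is the sum, one way around the error is to keep the arithmetic inside JAX instead of converting each element to a Python `int`. A minimal sketch (the function names are hypothetical): `i` is still a Tracer here, but nothing forces it to become concrete, so tracing succeeds.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fn_traced_sum():
    # `count` becomes a Tracer as well; no Python-level int() conversion is needed.
    count = 0
    for i in jnp.array([0, 1, 2, 3, 4]):
        count = count + i
    return count


@jax.jit
def fn_reduction():
    # Idiomatic alternative: one reduction instead of a Python loop.
    return jnp.sum(jnp.array([0, 1, 2, 3, 4]))


print(fn_traced_sum())
print(fn_reduction())
```

Iterating over a traced array still unrolls into one indexing operation per element, so for longer arrays the single `jnp.sum` reduction keeps the traced program much smaller.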


```python
def fn():
    count = 0
    arr = jnp.array([0, 1, 2, 3, 4])
    print(arr)
    for i in arr:
        count += int(i)
    return count


jax.jit(fn)()
```

```
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function fn at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_33732/3720083227.py:1 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[5] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_33732/3720083227.py:3:10 (fn)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

In fact, when we print `arr`, we see that it is a Tracer wrapping a `ShapedArray` instead of a concrete array. That's unintuitive! The JAX array `arr` is hardcoded in the function, so why is it converted into a Tracer?
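A side note on the `print` output: Python's `print` runs only while JAX traces the function, which is why it shows a Tracer. To see the array's values when the compiled function actually executes, JAX provides `jax.debug.print`. A small sketch:

```python
import jax
import jax.numpy as jnp


@jax.jit
def fn():
    arr = jnp.array([0, 1, 2, 3, 4])
    print(arr)  # Python print: runs during tracing and shows a Tracer
    jax.debug.print("arr = {}", arr)  # runs when the compiled code executes
    return jnp.sum(arr)


result = fn()
```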

One hypothesis is that JAX always treats data that "resides" on the GPU as a Tracer. But why would it?
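A quick experiment that may help narrow this down is to build the same values with NumPy, whose arrays are ordinary host-memory objects rather than JAX arrays, and compare against the `jnp.array` version. It is also worth noting that both error logs above attribute the Tracer to a `convert_element_type` operation on the line that calls `jnp.array`. A minimal sketch (the function names are hypothetical):

```python
import jax
import jax.numpy as jnp
import numpy as np


def fn_numpy():
    count = 0
    # np.array is plain host data, not a JAX operation, so its
    # elements are concrete at trace time and int() works.
    for i in np.array([0, 1, 2, 3, 4]):
        count += int(i)
    return count


def fn_jnp():
    count = 0
    # jnp.array is called while jit is tracing, so each `i` is a Tracer.
    for i in jnp.array([0, 1, 2, 3, 4]):
        count += int(i)
    return count


numpy_result = jax.jit(fn_numpy)()

try:
    jax.jit(fn_jnp)()
    jnp_failed = False
except jax.errors.ConcretizationTypeError:
    jnp_failed = True

print(numpy_result, jnp_failed)
```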
