In [5]:
import jax
import jax.numpy as jnp
from functools import partial

### Original failing code:


In [1]:
schedule = jnp.array([10, 20, 30, 40, 50])
updates_per_schedule = 10


def train():
    def _meta_step(meta_state, _):
        count = meta_state

        # Ranges from 10 to 50
        num_steps = schedule[count // updates_per_schedule]

        def _env_step(env_state, _):
            return env_state, None

        env_state = jax.lax.scan(_env_step, None, None, num_steps)

        return count + 1, env_state

    count = 0
    count, _ = jax.lax.scan(_meta_step, count, None, updates_per_schedule * len(schedule))


jax.jit(train)()

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The `length` argument to `scan` expects a concrete `int` value.
The error occurred while tracing the function _meta_step at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_31168/2946125733.py:9 for scan. This concrete value was not available in Python because it depends on the value of the argument meta_state.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

### Fix attempt 1:


In [3]:
schedule = jnp.array([10, 20, 30, 40, 50])
updates_per_schedule = 10


def train():
    def _meta_step(meta_state, num_steps):
        count = meta_state

        def _env_step(env_state, _):
            return env_state, None

        env_state = jax.lax.scan(_env_step, None, None, num_steps)

        return count + 1, env_state

    count = 0
    count, _ = jax.lax.scan(
        _meta_step, count, jnp.repeat(schedule, updates_per_schedule), updates_per_schedule * len(schedule)
    )
    return count


jax.jit(train)()

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The `length` argument to `scan` expects a concrete `int` value.
The error occurred while tracing the function _meta_step at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_31168/1645366953.py:6 for scan. This concrete value was not available in Python because it depends on the value of the argument num_steps.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

### Fix attempt 2:


In [10]:
schedule = jnp.array([10, 20, 30, 40, 50])
updates_per_schedule = 10


def train():
    @partial(jax.jit, static_argnums=(1,))
    def _meta_step(meta_state, num_steps):
        count = meta_state

        def _env_step(env_state, _):
            return env_state, None

        env_state = jax.lax.scan(_env_step, None, None, num_steps)

        return count + 1, env_state

    count = 0
    for num_steps in jnp.repeat(schedule, updates_per_schedule):
        count, _ = _meta_step(count, int(num_steps))

    return count


jax.jit(train)()

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function train at /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_31168/3242141445.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a[35m:i32[5,10][39m = broadcast_in_dim[
  broadcast_dimensions=(0,)
  shape=(5, 10)
  sharding=None
] b
    from line /var/folders/q4/2lsmb6qd1ks8137720rg8fz80000gn/T/ipykernel_31168/3242141445.py:17:21 (train)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Ok, the above doesn't work because Jax arrays become Tracers when being jit compiled. This means `num_steps` in the for loop will be a Tracer, so we cannot apply `int` on it. Let's try using plain Python lists.


### Fix attempt 3:


In [62]:
schedule = [10, 20, 30, 40, 50]
updates_per_schedule = 10
schedule_list = (
    [10] * updates_per_schedule
    + [20] * updates_per_schedule
    + [30] * updates_per_schedule
    + [40] * updates_per_schedule
    + [50] * updates_per_schedule
)


@jax.jit
def train1():
    def _meta_step(meta_state, num_steps):
        count = meta_state

        def _env_step(env_state, _):
            return env_state, None

        env_state = jax.lax.scan(_env_step, None, None, num_steps)

        return count + 1, env_state

    count = 0
    for num_steps in schedule_list:
        count, _ = _meta_step(count, int(num_steps))

    return count


train1()

Array(50, dtype=int32, weak_type=True)

The above solution might get too expensive if `_meta_step` is expensive to compile. Which leads to our final attempt:


In [63]:
def train2():
    @partial(jax.jit, static_argnums=(1,))
    def _meta_step(meta_state, num_steps):
        count = meta_state

        def _env_step(env_state, _):
            return env_state, None

        env_state = jax.lax.scan(_env_step, None, None, num_steps)

        return count + 1, env_state

    count = 0
    for num_steps in schedule_list:
        count, _ = _meta_step(count, int(num_steps))

    return count


train2()

Array(50, dtype=int32, weak_type=True)

Let's compare performance of these two approaches.


In [64]:
%timeit train1()

2.69 μs ± 41.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [65]:
%timeit train2()


83.4 ms ± 171 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
