In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.tree_util import register_pytree_node
from jax import random

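### Problem: `jit` fails when an argument determines an array's shape

`random_initial_weights` uses `size` to build the output shape, and shapes must be concrete values when JAX traces a function.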

In [15]:
def random_initial_weights(size, r_key, offset=0.5, scale=0.001):
    """Return a ``(size, 1)`` float32 column of small random weights.

    Args:
        size: number of rows. It determines the output *shape*, so under
            ``jit`` it has to be a static (concrete) argument.
        r_key: a ``jax.random`` PRNG key.
        offset: subtracted from the standard-normal samples before scaling.
            NOTE(review): ``random.normal`` is already zero-mean, so the
            default 0.5 (apparently borrowed from the "center a uniform(0, 1)
            sample" idiom) shifts the weights' mean to ``-offset * scale``.
            Pass ``offset=0.0`` for zero-mean weights; the default is kept so
            the outputs recorded in this notebook stay reproducible.
        scale: multiplier applied after the offset, i.e. the weights' standard
            deviation.

    Returns:
        ``(normal(r_key, (size, 1)) - offset) * scale`` as a float32 array.
    """
    samples = random.normal(r_key, shape=(size, 1), dtype=jnp.float32)
    return (samples - offset) * scale

key = random.PRNGKey(0)
random_initial_weights(4, key)

DeviceArray([[ 0.00131609],
             [-0.00125489],
             [-0.00016011],
             [-0.00103484]], dtype=float32)

In [3]:
# Expected failure: under plain `jit`, `size` becomes an abstract tracer, but
# `random.normal` needs a concrete shape. The error is caught and shown so the
# notebook still survives "Restart & Run All".
try:
    jit(random_initial_weights)(4, key)
except TypeError as err:  # the failure raised here is a TypeError
    print(f"{type(err).__name__}: {err}")

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 1).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

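### Solution 1: Mark `size` as a static argument with `static_argnums`

`jit` does not ignore a static argument. It treats it as a compile-time constant and compiles a separate version for each distinct value.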

In [17]:
# Argument 0 (`size`) is static: JAX treats it as a compile-time constant and
# compiles a separate version of the function for each distinct value it sees.
optimized_random_initial_weights = jit(
    random_initial_weights,
    static_argnums=0,  # index of `size` in the signature
)
optimized_random_initial_weights(4, key)

DeviceArray([[ 0.00131609],
             [-0.00125489],
             [-0.00016011],
             [-0.00103484]], dtype=float32)

In [20]:
# Baseline: the same function without `jit`.
# `block_until_ready()` waits for JAX's asynchronous dispatch to finish, so the
# timing covers the actual computation rather than just the dispatch.
%timeit -n 100 random_initial_weights(4, key).block_until_ready()

460 µs ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
# Jitted version. It was already called once in the cell above, so its compiled
# version should already be cached and compilation is not part of this timing.
%timeit -n 100 optimized_random_initial_weights(4, key).block_until_ready()

57 µs ± 29 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


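### Problem: a Python `if` on a traced value fails under `jit`

Under `jit`, `x` is an abstract tracer, so `x < 0` cannot be turned into a Python `bool`. `grad` on its own still works, because it evaluates the function on concrete values.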

In [39]:
def _relu(x):
    if x < 0:
        return 0.0
    else:
        return x

(_relu(1.0),_relu(-5.0),_relu(5.0))
    

(1.0, 0.0, 5.0)

In [64]:
relu = jit(_relu)

# Expected failure: `if x < 0` needs a concrete bool, but under `jit` `x` is an
# abstract tracer. The error is caught and shown instead of aborting the run.
try:
    (relu(1.0), relu(-5.0), relu(5.0))
except TypeError as err:  # JAX's ConcretizationTypeError subclasses TypeError
    print(f"{type(err).__name__}: {err}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function _relu at <ipython-input-39-569c0715d112>:1, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to _relu at <ipython-input-39-569c0715d112>:1, transformed by jit. at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

In [65]:
# `grad` works with the Python branch: unlike `jit`, it runs `_relu` on concrete
# input values.
relu_grad = grad(_relu)

tuple(relu_grad(point) for point in (1.0, -5.0, 5.0))

(array(1., dtype=float32), array(0., dtype=float32), array(1., dtype=float32))

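### Solution: avoid Python branching on traced values

Express the logic with array operations (e.g. `jnp.maximum`, `jnp.where`, or `jax.lax.cond`), so no Python `if` has to be evaluated while tracing.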

In [66]:
def _relu(x):
    return jnp.max(jnp.array([x, 0.0]))
    
relu = jit(_relu)   # No condition - can compile!
(relu(1.0),relu(-5.0),relu(5.0))

(DeviceArray(1., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(5., dtype=float32))

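### Problem: a Python loop bound needs a concrete value

`range(n)` requires a Python integer, but under `jit` the argument `n` is an abstract tracer.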

In [67]:
def _add2(X, n):
    result = []
    for i in range(n):
        result.append(X[i] * 2)
    return jnp.array(result)

x = jnp.array([0, 1, 2])

_add2(x, 3)

DeviceArray([0, 2, 4], dtype=int32)

In [56]:
add2 = jit(_add2)

# Expected failure: `range(n)` needs a concrete int, but under `jit` `n` is a
# tracer. The error is caught and shown instead of aborting the run.
try:
    add2(x, 2)
except TypeError as err:  # JAX's TracerIntegerConversionError subclasses TypeError
    print(f"{type(err).__name__}: {err}")

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError)

### Solution: derive the loop bound from the input's shape

Array shapes stay concrete during tracing, so `X.shape[0]` is a plain Python integer even under `jit`.

In [75]:
def _add2(X):
    result = []
    print(X.shape[0])
    for i in range(X.shape[0]): # the actual value for shape (on X) is taken from the input
        result.append(X[i] * 2)
    return jnp.array(result)

x = jnp.array([0, 1, 2, 3])
_add2(x)

4


DeviceArray([0, 2, 4, 6], dtype=int32)

In [71]:
# Under `jit`, X.shape[0] is still a concrete Python int, so the loop inside
# `_add2` runs at trace time and this compiles successfully.
add2 = jit(_add2)
add2(x)

DeviceArray([0, 2, 4, 6], dtype=int32)