In [1]:
import jax
import numpy as np
import jax.numpy as jnp

from jax import jit
from jax import lax
# Don't forget to set the runtime type to "GPU"

In [2]:
def some_computation(x):
  """Evaluate the cubic polynomial x + 2*x**2 + 3*x**3 elementwise.

  Only ``+`` and ``*`` are used, so the same function works on Python scalars,
  NumPy arrays and JAX arrays.
  """
  linear_term = x
  square_term = 2*x*x
  cube_term = 3*x*x*x
  return linear_term + square_term + cube_term

In [3]:
# Baseline: time the polynomial with NumPy on the host.
# NOTE(review): the data is unseeded, so values differ between runs; the timing
# itself does not depend on the values.
x_np = np.random.normal(size= (10_000, 10_000)).astype(np.float32)
%timeit -n5 some_computation(x_np)

621 ms ± 41.6 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [4]:
x_jax = jax.random.normal(jax.random.PRNGKey(0), (10_000, 10_000), dtype=jnp.float32)
%timeit -n5 some_computation(x_jax).block_until_ready()

43.3 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [5]:
some_computation_jax = jit(some_computation)
%timeit -n5 some_computation_jax(x_jax).block_until_ready()

The slowest run took 6.32 times longer than the fastest. This could mean that an intermediate result is being cached.
6.84 ms ± 7.23 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [6]:
# Time a single call of the jitted function; it was already compiled by the cells above.
%time some_computation_jax(x_jax).block_until_ready()

CPU times: user 1.11 ms, sys: 35 µs, total: 1.15 ms
Wall time: 3.93 ms


Array([[ -5.123139  ,  -2.5528057 ,   0.54589796, ...,  -0.91849977,
        -11.3729725 ,  -0.13426392],
       [ -0.38166204,  19.278793  ,   8.179087  , ...,  -0.11817677,
         11.692125  ,  -8.095479  ],
       [ 41.685886  ,  -0.82572985,   1.1550295 , ...,  -3.5590916 ,
         -0.72880155,  -0.252419  ],
       ...,
       [ 20.680542  ,  35.664925  ,   2.5015452 , ...,   8.285651  ,
          1.961968  ,  -0.2965622 ],
       [ -4.633772  ,   8.397433  ,   0.20800267, ...,   0.06732942,
          0.11885115,   0.0623237 ],
       [ -0.4510728 ,   2.6750205 ,  -0.13392928, ...,  -0.16722514,
          0.9809772 ,   6.075848  ]], dtype=float32)

In [7]:
@jit
def some_computation_jit_decorated(x):
  """Evaluate x + 2*x**2 + 3*x**3 elementwise, compiled with the ``@jit`` decorator.

  This is the decorator-syntax counterpart of wrapping the plain function with
  ``jit(...)`` explicitly.
  """
  linear_term = x
  square_term = 2*x*x
  cube_term = 3*x*x*x
  return linear_term + square_term + cube_term

In [8]:
%timeit -n5 some_computation_jit_decorated(x_jax).block_until_ready()

5.05 ms ± 2.73 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [9]:
@jax.jit
def some_function(x, y):
  """Return ``jnp.dot(x, y)`` under ``jit``, printing the inputs and the result.

  The ``print`` calls are plain Python side effects. They execute only while JAX
  traces the function, so they show abstract ``Traced<ShapedArray(...)>`` values
  rather than numbers. They stay silent on calls served from the compile cache.
  """
  for line in (f'x = {x}', f'y = {y}'):
    print(line)
  result = jnp.dot(x, y)
  print(f'result = {result}')
  return result

In [10]:
# inputs shape both (10_000, 10_000)
# First JIT-run: JAX traces the function with abstract inputs, so the print calls
# show Traced<ShapedArray(...)> objects instead of concrete values.
some_function(x_jax, x_jax.T)

x = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=1/0)>


Array([[ 1.00604980e+04, -9.98899002e+01, -1.03874687e+02, ...,
        -1.23548485e+02,  4.76367607e+01, -1.07481445e+02],
       [-9.98899002e+01,  9.81833301e+03,  1.29423080e+02, ...,
         2.41555939e+01, -5.59050674e+01, -1.96277191e+02],
       [-1.03874687e+02,  1.29423080e+02,  1.01703213e+04, ...,
        -2.91882648e+01, -1.04302149e+01,  3.90430717e+01],
       ...,
       [-1.23548485e+02,  2.41555939e+01, -2.91882648e+01, ...,
         1.00786758e+04, -5.32248650e+01,  4.03949499e+00],
       [ 4.76367607e+01, -5.59050674e+01, -1.04302149e+01, ...,
        -5.32248650e+01,  9.87220312e+03, -1.78523216e+01],
       [-1.07481445e+02, -1.96277191e+02,  3.90430717e+01, ...,
         4.03949499e+00, -1.78523216e+01,  1.01478525e+04]],      dtype=float32)

In [11]:
# Second JIT-run with the same shapes and dtypes: the compiled computation is
# reused from the cache and the Python body is not traced again. The print calls
# are Python side effects that only run during tracing, so nothing is printed.
some_function(x_jax, x_jax.T)

Array([[ 1.00604980e+04, -9.98899002e+01, -1.03874687e+02, ...,
        -1.23548485e+02,  4.76367607e+01, -1.07481445e+02],
       [-9.98899002e+01,  9.81833301e+03,  1.29423080e+02, ...,
         2.41555939e+01, -5.59050674e+01, -1.96277191e+02],
       [-1.03874687e+02,  1.29423080e+02,  1.01703213e+04, ...,
        -2.91882648e+01, -1.04302149e+01,  3.90430717e+01],
       ...,
       [-1.23548485e+02,  2.41555939e+01, -2.91882648e+01, ...,
         1.00786758e+04, -5.32248650e+01,  4.03949499e+00],
       [ 4.76367607e+01, -5.59050674e+01, -1.04302149e+01, ...,
        -5.32248650e+01,  9.87220312e+03, -1.78523216e+01],
       [-1.07481445e+02, -1.96277191e+02,  3.90430717e+01, ...,
         4.03949499e+00, -1.78523216e+01,  1.01478525e+04]],      dtype=float32)

In [12]:
# Run the same JIT-compiled function with inputs of a different shape.
# (100, 100) is a new input shape, so JAX retraces and the prints fire again.
x_jax_100 = jax.random.normal(jax.random.PRNGKey(0), (100, 100), dtype=jnp.float32)
some_function(x_jax_100, x_jax_100.T)

x = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=1/0)>


Array([[ 88.81407  ,   2.9003427,  -5.972798 , ...,  -4.707271 ,
          1.6306962, -18.483522 ],
       [  2.9003427,  90.007965 , -19.595284 , ...,   1.6136569,
         11.12668  ,  -4.6877217],
       [ -5.972798 , -19.595284 ,  81.10071  , ...,   1.0915028,
        -25.524742 ,   8.158062 ],
       ...,
       [ -4.707271 ,   1.6136569,   1.0915028, ...,  98.916084 ,
         -9.319294 ,   3.4946725],
       [  1.6306962,  11.12668  , -25.524742 , ...,  -9.319294 ,
         87.156    ,  -1.4231719],
       [-18.483522 ,  -4.6877217,   8.158062 , ...,   3.4946725,
         -1.4231719, 101.68203  ]], dtype=float32)

In [13]:
# This time the argument is another array with the same shape and dtype (different values and name).
# The cache is keyed on shape/dtype, not on values or variable names, so there is no retrace and no prints.
y_jax_100 = jax.random.normal(jax.random.PRNGKey(1), (100, 100), dtype=jnp.float32)
some_function(x_jax_100, y_jax_100.T)

Array([[ -7.3482647 ,  -3.6499982 ,  10.739452  , ...,   4.888681  ,
         17.538935  ,   5.201276  ],
       [ -0.08233333,   2.2109277 ,  15.598539  , ..., -12.977024  ,
        -10.567622  ,  -4.1360354 ],
       [ -4.844488  ,   8.9012165 ,  10.684648  , ...,   0.39691117,
          2.0356302 ,  -2.8063076 ],
       ...,
       [ -7.4417977 ,  13.557502  , -15.566593  , ...,  16.8637    ,
         11.441824  ,  -9.865422  ],
       [ 13.538433  ,  -8.893979  ,   2.9191327 , ...,   1.848477  ,
          3.9006677 ,   8.601546  ],
       [-12.764776  ,  23.827639  ,  -4.7825894 , ...,  -9.058582  ,
         -4.3704195 ,   7.200335  ]], dtype=float32)

In [14]:
# The JIT cache is keyed not only on the shapes of the input arrays but also on their data types.
# We now pass 'int32' arrays instead of 'float32', so JAX retraces the function just as it did on the first call.
x_jax_100_int = jnp.eye(100, dtype=jnp.int32)
some_function(x_jax_100_int, x_jax_100_int.T)

x = Traced<ShapedArray(int32[100,100])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(int32[100,100])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(int32[100,100])>with<DynamicJaxprTrace(level=1/0)>


Array([[1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 0, 1]], dtype=int32)

In [15]:
# Impure functions and JIT: I/O streams
def return_same_value(x):
  """Return ``x`` unchanged, printing a message as a Python side effect.

  Under ``jit`` the ``print`` only runs while the function is being traced.
  It therefore reveals when JAX (re)traces, rather than firing on every call.
  """
  message = 'Return input value at output'
  print(message)
  return x

# First call: traces (the message prints) and compiles for a weakly-typed float scalar.
jit(return_same_value)(2.)

Return input value at output


Array(2., dtype=float32, weak_type=True)

In [16]:
# Same function and same input type (float scalar): jit() is called again, yet the
# cached compilation is reused and the message does not print.
jit(return_same_value)(6.)

Array(6., dtype=float32, weak_type=True)

In [17]:
# A Python int instead of a float is a different dtype (int32), so JAX retraces and the message prints again.
jit(return_same_value)(6)

Return input value at output


Array(6, dtype=int32, weak_type=True)

In [18]:
# Impure functions and JIT: Global state
power = 5

def power_of(x):
  """Raise ``x`` to the module-level global ``power``.

  The global is read whenever the body runs. Under ``jit`` that happens only at
  trace time, so the value seen then is baked into the compiled computation.
  """
  exponent = power
  return x ** exponent

In [19]:
# First trace of power_of for a Python-int input: the global `power` (currently 5) is read during tracing.
x_5 = jit(power_of)(2)
x_5

Array(32, dtype=int32, weak_type=True)

In [20]:
# Rebinding the global does not invalidate JAX's compile cache. The input type is
# the same (Python int scalar), so the cached computation with power=5 is reused
# and the output is 32, not 1024.
power = 10
x_10 = jit(power_of)(2)
x_10

Array(32, dtype=int32, weak_type=True)

The result is wrong! `power` is a global variable. When `jit(power_of)` was first traced, `power` was 5, and that value was baked into the compiled computation as a constant. Calling `jit(power_of)` again with an argument of the same type and shape (a Python int) reuses the cached compilation, so the new global value (10) is never read.

In [21]:
# IMPORTANT! Passing a float (2.0) instead of an int (2) changes the input dtype, which forces JIT to retrace and recompile the function.
# The global `power` was set to 10 above. Because the function is retraced, the current value of `power` is read and the result is correct (2.0**10 = 1024.0).
x_10 = jit(power_of)(2.0)
x_10

Array(1024., dtype=float32, weak_type=True)

In [22]:
# A shape-(1,) int32 array is yet another input signature, so this call also retraces and picks up power = 10.
x_10 = jit(power_of)(jnp.array([2]))
x_10

Array([1024], dtype=int32)

In [23]:
# Impure functions and JIT: Iterators
# In Python, iterators are stateful objects: each next() call changes them.
# JAX transformations expect pure functions, so consuming an iterator inside a traced function misbehaves.
array_jax = jnp.arange(5)
array_jax

Array([0, 1, 2, 3, 4], dtype=int32)

In [24]:
# Pure version: the loop body indexes array_jax with the loop counter i, so the result is 0+1+2+3+4 = 10.
lax.fori_loop(0, 5, lambda i, x: x + array_jax[i], 0)

Array(10, dtype=int32)

In [26]:
iterator = iter(range(5)) # iterator is a stateful object!

# The loop body is traced only once, so next(iterator) runs a single time and returns 0.
# That constant is baked into every iteration, so the printed result is 0 instead of 10.
print(lax.fori_loop(0, 5, lambda i, x: x + next(iterator), 0)) # because of the iterator, lambda function is no longer a pure function

0


In [33]:
def pure_function_with_internal_state(array):
    """Return the elements of ``array`` as a Python list, one entry per index.

    The list is local to each call, so this is pure apart from the ``print``.
    Under ``jit`` the ``print`` only fires while JAX traces the function, which
    happens again whenever the input shape changes.
    """
    print('fresh JIT-compilation!')
    # Internal (local) state: a brand-new list built on every call.
    return [array[idx] for idx in range(len(array))]

In [35]:
# NOTE(review): on a fresh kernel this call also prints 'fresh JIT-compilation!'.
# The saved output lacks it, presumably because a since-removed cell (In [34])
# had already compiled the function for a length-5 int32 array.
array = jnp.arange(5)
jit(pure_function_with_internal_state)(array)

[Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(2, dtype=int32),
 Array(3, dtype=int32),
 Array(4, dtype=int32)]

In [36]:
# A length-10 array is a new input shape, so JAX retraces and the print fires.
array = jnp.arange(10)
jit(pure_function_with_internal_state)(array)

fresh JIT-compilation!


[Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(2, dtype=int32),
 Array(3, dtype=int32),
 Array(4, dtype=int32),
 Array(5, dtype=int32),
 Array(6, dtype=int32),
 Array(7, dtype=int32),
 Array(8, dtype=int32),
 Array(9, dtype=int32)]

### Jaxprs (JAX expressions)

In [37]:
import jax
import jax.numpy as jnp

from jax import jit, grad, random

In [38]:
def relu(x):
    """Rectified linear unit: elementwise max(0.0, x)."""
    zero = 0.0
    return jnp.maximum(zero, x)

In [39]:
# make_jaxpr traces relu with an abstract f32[] input and prints the resulting
# intermediate representation: a single `max` primitive.
print(jax.make_jaxpr(relu)(5.))

{ lambda ; a:f32[]. let b:f32[] = max 0.0 a in (b,) }


In [40]:
# This example shows why JAX jit-compiled functions don't work with global variables.
# The Python side effect (appending to input_list) runs while tracing but does not
# appear in the jaxpr printed below. Calls served by a compiled version would
# therefore never perform it.
input_list = []

def sigmoid(x):
    # `global` is not strictly required (the list is mutated, not rebound), but it
    # makes the reliance on module-level state explicit.
    global input_list
    # Side effect: absent from the jaxpr. After make_jaxpr, input_list holds the
    # tracer object rather than the concrete 5.0.
    input_list.append(x)
    res = 1.0 / (1.0 + jnp.exp(-x))
    return res

print(jax.make_jaxpr(sigmoid)(5.0))

{ lambda ; a:f32[]. let
    b:f32[] = neg a
    c:f32[] = exp b
    d:f32[] = add 1.0 c
    e:f32[] = div 1.0 d
  in (e,) }


### Control flow statements and JIT

In [41]:
def f(x):
    """Piecewise function: 3*x**3 + 2*x**2 + 5*x for x > 0, otherwise 2*x.

    This works on concrete Python/NumPy values. Under ``jax.jit``, ``x`` is an
    abstract tracer, so ``x > 0`` has no concrete truth value and the Python
    ``if`` raises ``TracerBoolConversionError``.
    """
    if x > 0:
        return 3 * x**3 + 2 * x**2 + 5 * x
    else:
        return 2 * x

jitted_fn = jax.jit(f)
# The error is the point of this demo. Catch and display it instead of letting the
# cell fail, so "Restart & Run All" can continue past this cell.
# TracerBoolConversionError is a subclass of ConcretizationTypeError, so catching
# the base class covers it.
try:
    jitted_fn(10)
except jax.errors.ConcretizationTypeError as err:
    print(f"{type(err).__name__}: {err}")

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at <ipython-input-41-acee613a3c6b>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [42]:
@jit
def exp_func(x):
    """Cubic polynomial 3*x**3 + 2*x**2 + 5*x, compiled with ``@jit``.

    Despite the name, no exponential is involved.
    """
    cubic = 3 * x**3
    quadratic = 2 * x**2
    linear = 5 * x
    return cubic + quadratic + linear

In [43]:
def f_inner_jitted(x):
    """Branch in plain Python, then call the jitted ``exp_func`` when x > 0.

    ``f_inner_jitted`` itself is not jitted, so the ``if`` sees the concrete value
    of ``x`` and no tracer is converted to bool. Only ``exp_func`` is compiled.
    """
    if x > 0:
        # exp_func is already decorated with @jit; wrapping it in jit() again is redundant.
        return exp_func(x)
    else:
        return 2 * x

f_inner_jitted(10)

Array(3250, dtype=int32, weak_type=True)