In [1]:
import jax
import numpy as np
import jax.numpy as jnp

from jax import jit
from jax import lax

In [2]:
def some_computation(x):
    """Evaluate the cubic polynomial x + 2*x**2 + 3*x**3 element-wise.

    Works on Python scalars, NumPy arrays and JAX arrays alike, which is what
    lets the benchmarks below compare NumPy, eager JAX and jitted JAX.
    """
    # Terms are written with the same multiplication order as the textbook form
    # so that floating-point rounding is unchanged.
    linear = x
    quadratic = 2 * x * x
    cubic = 3 * x * x * x
    return linear + quadratic + cubic

In [3]:
x_np = np.random.normal(size = (10000, 10000)).astype(np.float32)

%timeit -n5 some_computation(x_np)

5 loops, best of 5: 465 ms per loop


In [6]:
x_jax = jax.random.normal(jax.random.PRNGKey(0), (10000, 10000), dtype = jnp.float32)

%timeit -n5 some_computation(x_jax)

The slowest run took 6.31 times longer than the fastest. This could mean that an intermediate result is being cached.
5 loops, best of 5: 5.26 ms per loop


In [7]:
# Wrap the plain Python function with jit: it is traced on the first call and
# compiled with XLA; later calls with the same shape/dtype reuse the compiled version.
some_computation_jax = jit(some_computation)

# block_until_ready() waits for the asynchronously dispatched result, so the
# timing covers the actual computation.
# NOTE(review): the first loop also pays for tracing + compilation; %timeit reports
# the best repeat, which presumably excludes that one-off cost.
%timeit -n5 some_computation_jax(x_jax).block_until_ready()

5 loops, best of 5: 3.79 ms per loop


In [8]:
@jit
def some_computation_jax_decorated(x):
    """Same polynomial as `some_computation` (x + 2*x**2 + 3*x**3), jitted via decorator.

    Using `@jit` is equivalent to writing `jit(some_computation)`; the decorator
    simply applies the transformation at definition time.
    """
    linear = x
    quadratic = 2 * x * x
    cubic = 3 * x * x * x
    return linear + quadratic + cubic

In [9]:
# Should match the jit(some_computation) timing above: the decorator is just
# syntax for the same transformation.
%timeit -n5 some_computation_jax_decorated(x_jax).block_until_ready()

5 loops, best of 5: 3.77 ms per loop


# Tracer Objects

In the cells below, notice that the `print` statements inside the jitted function execute. However, rather than printing the data we passed to the function, they print *tracer* objects that stand in for it.

These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the *shape* and *dtype* of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python.

In [10]:
@jit
def some_function(x, y):
  """Return the matrix product ``jnp.dot(x, y)``, printing its inputs while tracing.

  The print calls are Python side effects. They run only while ``jit`` traces
  the function (the first call for a given input shape/dtype), so they show
  tracer objects rather than concrete values. Calls that reuse a cached
  compilation print nothing.
  """

  print('Running some_fn()')

  # x and y are tracers here: only their shape and dtype are known.
  print(f'x = {x}')
  print(f'y = {y}')

  result = jnp.dot(x, y)

  # result is also a tracer; no numeric value has been computed yet.
  print(f'result = {result}')

  return result

In [11]:
# Both inputs have shape (10000, 10000) and dtype float32.
# This is the first call with this shape/dtype, so jit traces some_function and the prints run.
some_function(x_jax, x_jax.T)

Running some_fn()
x = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[10000,10000])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([[ 1.00604980e+04, -9.98899002e+01, -1.03874687e+02, ...,
              -1.23548485e+02,  4.76367607e+01, -1.07481445e+02],
             [-9.98899002e+01,  9.81833301e+03,  1.29423080e+02, ...,
               2.41555939e+01, -5.59050674e+01, -1.96277191e+02],
             [-1.03874687e+02,  1.29423080e+02,  1.01703213e+04, ...,
              -2.91882648e+01, -1.04302149e+01,  3.90430717e+01],
             ...,
             [-1.23548485e+02,  2.41555939e+01, -2.91882648e+01, ...,
               1.00786758e+04, -5.32248650e+01,  4.03949499e+00],
             [ 4.76367607e+01, -5.59050674e+01, -1.04302149e+01, ...,
              -5.32248650e+01,  9.87220312e+03, -1.78523216e+01],
             [-1.07481445e+02, -1.96277191e+02,  3.90430717e+01, ...,
               4.03949499e+00, -1.78523216e+01,  1.01478525e+04]],            dtype=float32)

In [12]:
# Both inputs have shape (10000, 10000) and dtype float32, the same as the previous call.
# The cached compilation is reused, so nothing is printed.
some_function(x_jax.T, x_jax)

DeviceArray([[ 1.01733340e+04,  1.20691864e+02,  2.18645535e+01, ...,
               1.54326508e+02, -3.17743263e+01, -3.51457520e+01],
             [ 1.20691864e+02,  9.92143555e+03,  2.49563065e+01, ...,
              -1.72294632e+02,  1.29981737e+01, -7.11319494e+00],
             [ 2.18645535e+01,  2.49563065e+01,  1.01241816e+04, ...,
              -2.22840881e+01, -6.63894424e+01, -4.59337234e+00],
             ...,
             [ 1.54326508e+02, -1.72294632e+02, -2.22840881e+01, ...,
               9.96794141e+03, -4.44877481e+00, -2.06930981e+01],
             [-3.17743263e+01,  1.29981737e+01, -6.63894424e+01, ...,
              -4.44877481e+00,  1.00113545e+04,  1.69023773e+02],
             [-3.51457520e+01, -7.11319494e+00, -4.59337234e+00, ...,
              -2.06930981e+01,  1.69023773e+02,  1.01222676e+04]],            dtype=float32)

In [13]:
x_jax_100 = jax.random.normal(jax.random.PRNGKey(0), (100, 100), dtype = jnp.float32)

# Both inputs are (100, 100) float32: same dtype as the earlier compiled function
# but a new shape, so jit re-traces and the prints run again.

some_function(x_jax_100.T, x_jax_100)

Running some_fn()
x = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/1)>
y = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/1)>
result = Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([[130.1622   , -11.1545315,   6.395634 , ...,  -0.6379199,
               -1.0903095,  -3.4633865],
             [-11.1545315,  81.10213  ,  -8.001425 , ...,   7.4121494,
               18.26979  ,  13.440549 ],
             [  6.395634 ,  -8.001425 , 107.0456   , ...,   4.493244 ,
                9.876643 ,  15.961442 ],
             ...,
             [ -0.6379199,   7.4121494,   4.493244 , ..., 103.51143  ,
                1.7205739,   6.656181 ],
             [ -1.0903095,  18.26979  ,   9.876643 , ...,   1.7205739,
               99.39847  ,   6.9601545],
             [ -3.4633865,  13.440549 ,  15.961442 , ...,   6.656181 ,
                6.9601545,  89.10493  ]], dtype=float32)

In [14]:
# Different values (PRNG key 1 instead of 0), but the same shape and dtype as before.
y_jax_100 = jax.random.normal(jax.random.PRNGKey(1), (100, 100), dtype = jnp.float32)

# Both inputs are (100, 100) float32, matching the previous call, so the cached
# compilation is reused and nothing is printed.

some_function(y_jax_100.T, y_jax_100)

DeviceArray([[ 85.39795  ,   6.7500253, -19.024044 , ...,  11.312255 ,
               17.6198   ,   4.2592945],
             [  6.7500253,  97.65387  ,  -4.9124317, ...,  11.003968 ,
               15.469701 ,   7.7840524],
             [-19.024044 ,  -4.9124317, 119.65793  , ..., -10.738948 ,
               27.536476 ,   7.1068816],
             ...,
             [ 11.312255 ,  11.003968 , -10.738948 , ..., 109.41342  ,
               15.596954 ,   4.2096057],
             [ 17.6198   ,  15.469701 ,  27.536476 , ...,  15.596954 ,
              111.3831   ,  34.37196  ],
             [  4.2592945,   7.7840524,   7.1068816, ...,   4.2096057,
               34.37196  ,  92.73889  ]], dtype=float32)

In [23]:
x_jax_100_int = jnp.eye(100, dtype = jnp.int32)

# Both inputs are (100, 100) but now int32: a new dtype forces a re-trace, so on a
# fresh kernel the prints run again.
# NOTE(review): the saved output below shows no trace prints, presumably because this
# cell (In [23]) was re-executed after its first run. Restart & Run All to see them.

some_function(x_jax_100_int.T, x_jax_100_int)

DeviceArray([[1, 0, 0, ..., 0, 0, 0],
             [0, 1, 0, ..., 0, 0, 0],
             [0, 0, 1, ..., 0, 0, 0],
             ...,
             [0, 0, 0, ..., 1, 0, 0],
             [0, 0, 0, ..., 0, 1, 0],
             [0, 0, 0, ..., 0, 0, 1]], dtype=int32)

# Pure Functions

According to [Wikipedia](https://en.wikipedia.org/wiki/Pure_function), a function is pure if:
1. The function returns the same values when invoked with the same inputs
2. There are no side effects observed on a function call

JAX [pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions):

- No Python control flow that depends on the *values* of traced arguments (under `jit`, only shapes and dtypes are known while tracing)
- It cannot use or change global state (variables outside its scope, global variables)
- It cannot perform I/O - so no printing, asking for input, or accessing the time
- It cannot take a mutable object (such as a list or an iterator) as an argument, which a concurrent process could modify


If you use `jit` with impure functions, in some cases you will get an observable error and be able to fix your code. But it is also possible that jitting impure functions silently gives you wrong results.

# Impure Functions: I/O

A function becomes impure when it contains input/output statements such as `print`.

The side effect, i.e. `print('Return input value at output')`, appears during the first run but is not printed on subsequent runs with arguments of the same type and shape.

The output stream that `print` writes to is external state, so the function still depends on external state.
What if the stream isn't available on a subsequent call for whatever reason? That would violate the first principle of "returning the same thing" when called with the same inputs.





In [None]:
def return_same_value(x):
  """Identity function with a deliberate side effect (a print).

  Under ``jit`` the print runs only while the function is being traced, not
  when a cached compilation is reused.
  """
  message = 'Return input value at output'
  print(message)
  return x

jit(return_same_value)(2.)

Return input value at output


DeviceArray(2., dtype=float32, weak_type=True)

Subsequent runs with arguments of the same type and shape do not show the side effect.
This is because JAX now invokes a cached compilation of the function.

In [None]:
# Same type (float scalar) and shape as the previous call: jit reuses the cached
# compilation, so the print does not run.
jit(return_same_value)(6.)

DeviceArray(6., dtype=float32, weak_type=True)

When the type or shape of the argument changes (here a Python `int` instead of a `float`), JAX re-traces the Python function and the side effect reappears.

In [None]:
# A Python int instead of a float changes the dtype, so JAX re-traces and the print runs again.
jit(return_same_value)(6)

Return input value at output


DeviceArray(6, dtype=int32, weak_type=True)

# Impure Functions: Globals


Defining a function that relies on the global variable `power` for its
computation
    

In [None]:
power = 5

def power_of(x):
    """Raise ``x`` to the module-level global ``power``.

    Deliberately impure: the exponent is read from global state at call time
    instead of being passed in as an argument.
    """
    exponent = power  # looked up in the global scope on every (un-jitted) call
    return x ** exponent

In [None]:
x = 2

# Plain Python call: the global power (currently 5) is read at call time.
x_5 = power_of(x)

x_5

32

In [None]:
# Changing the global changes the result of the un-jitted function.
power = 10

x_10 = power_of(x)

x_10

1024

We will `jit` the function so that it runs as a JAX-transformed
function rather than as a normal Python function.

In [None]:
power = 5

# First jitted call with an int argument: tracing reads power == 5 and bakes it
# into the compiled computation as a constant.
x_5 = jit(power_of)(x)

x_5

DeviceArray(32, dtype=int32, weak_type=True)

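Mathematically we expect x_10 below to be 1024 (2**10), but 32 is obtained: the jitted function still uses the value `power = 5` that was captured when it was first traced.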

In [None]:
power = 10

# Same argument type/shape as the previous jitted call, so the cached compilation
# (with power == 5 baked in) is reused: the result is 32, not 1024.
x_10 = jit(power_of)(x)

x_10

DeviceArray(32, dtype=int32, weak_type=True)

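When you `jit` your function, JAX traces it. During tracing, the current value of the global `power` is read once and baked into the compiled computation as a constant. On the first call the result is as expected, but subsequent calls reuse that **`cached`** compiled function (and therefore the old value of `power`) unless:
1. The type of the argument has changed, or
2. The shape of the argument has changed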

Let's see it in action

In [None]:
# A float argument changes the dtype, so JAX re-traces and picks up the current power == 10.
x = 2.0

x_10 = jit(power_of)(x)

x_10

DeviceArray(1024., dtype=float32, weak_type=True)

In [None]:
# A shape-(1,) array changes the argument shape, so JAX re-traces again with power == 10.
x = jnp.array([2])

x_10 = jit(power_of)(x)

x_10

DeviceArray([1024], dtype=int32)

# Impure Functions: Iterators

We will take a very simple example to see the side effect. We will add the numbers from `0 to 4` (i.e. `range(5)`) in two different ways:

1. Passing an actual array of numbers to a function
2. Passing an **`iterator`** object to the same function

In [None]:
# The values 0, 1, 2, 3, 4 (five elements).
array_jax = jnp.arange(5)

array_jax

DeviceArray([0, 1, 2, 3, 4], dtype=int32)

In [None]:
# Sum array_jax[0] + ... + array_jax[4] with JAX's loop primitive. The body indexes
# the array with the loop counter i, so it carries no hidden state; result: 10.
lax.fori_loop(0, 5, lambda i, x: x + array_jax[i], 0)

DeviceArray(10, dtype=int32)

It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model.

In [None]:
iterator = iter(range(5))

# fori_loop traces the body function once, so next(iterator) runs a single time
# during tracing and returns 0. That 0 is baked in as a constant, every iteration
# adds 0, and the printed result is 0 instead of 10.
print(lax.fori_loop(0, 5, lambda i, x: x + next(iterator), 0))

0


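Why did the result turn out to be zero in the second case?<br>
Because an `iterator` introduces **external state** to retrieve the next value. `lax.fori_loop` traces the loop body only once, so `next(iterator)` is called a single time during tracing and returns `0`. That `0` is baked into the compiled loop as a constant, so every iteration adds `0`.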

# Pure functions with stateful objects
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:

In [None]:
def pure_uses_internal_state(array):
    """Return the elements of ``array`` as a Python list.

    The list is created and filled only inside the function, so no state leaks
    in or out: the function stays pure even though it builds a stateful object.
    Under ``jit`` the Python loop is unrolled at trace time, using the static
    length of ``array``.
    """
    return [array[index] for index in range(len(array))]

array = jnp.arange(5)

jit(pure_uses_internal_state)(array)

[DeviceArray(0, dtype=int32),
 DeviceArray(1, dtype=int32),
 DeviceArray(2, dtype=int32),
 DeviceArray(3, dtype=int32),
 DeviceArray(4, dtype=int32)]

Second call to the function with a different input (a new shape, so `jit` re-traces)

In [None]:
# arange(10) has a different shape from arange(5), so jit re-traces and the
# Python loop is unrolled for 10 elements.
array = jnp.arange(10)

jit(pure_uses_internal_state)(array)

[DeviceArray(0, dtype=int32),
 DeviceArray(1, dtype=int32),
 DeviceArray(2, dtype=int32),
 DeviceArray(3, dtype=int32),
 DeviceArray(4, dtype=int32),
 DeviceArray(5, dtype=int32),
 DeviceArray(6, dtype=int32),
 DeviceArray(7, dtype=int32),
 DeviceArray(8, dtype=int32),
 DeviceArray(9, dtype=int32)]

So, to keep things **pure**, remember not to use anything inside a function that depends on **external state**, including I/O. If you do, transforming the function can give you unexpected results. You may end up wasting a lot of time debugging code whose transformed version returns a cached result, which is ironic, because pure functions are supposed to be easy to debug.