## What's JAX?

Let's begin by recalling what JAX is. JAX is an ML/DL library developed by Google. I'd explain JAX as follows: **J** stands for JIT compilation using XLA; **A** stands for auto-differentiation, an updated version of the Autograd library that also runs on top of XLA; and **X** refers to XLA itself (Accelerated Linear Algebra), a compiler that optimizes array computations so that they run much faster on accelerators (GPU/TPU) than plain NumPy.

In [None]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

### Array updates: `x.at[idx].set(y)`


In [None]:
jax_arr = jnp.zeros((3,3), dtype=jnp.float32)
jax_arr

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [None]:
jax_arr[1,:] = 1.

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [None]:
jax_arr.at[1,:].set(1.)

Array([[0., 0., 0.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)

In [None]:
a = jnp.array([1,2,3,4,5])
a[::2]

Array([1, 3, 5], dtype=int32)

### Out-of-bounds indexing

In [None]:
jnp.arange(10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [None]:
jnp.arange(10)[11]

Array(9, dtype=int32)

In [None]:
jnp.arange(10).at[11].get(mode='fill', fill_value = jnp.nan)

ValueError: cannot convert float NaN to integer

In [None]:
# finer-grained control
jnp.arange(10.).at[11].get(mode='fill', fill_value = jnp.nan)

Array(nan, dtype=float32)

### JAX doesn't like tuples or lists

In [None]:
jnp.sum([1,2,3])

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [None]:
jnp.sum(jnp.array([1,2,3]))

Array(6, dtype=int32)

### Review PRNGs and State

`numpy` uses a PRNG called the Mersenne Twister (MT19937). Its state is represented by a state vector of 624 32-bit unsigned integers plus a position index. Each call updates this global state (drawing one float consumes two `uint32`s), and because the state is hidden and shared, it can later cause surprises for end users who rely on reproducible randomness.

In [None]:
print(np.random.random())
print(np.random.random())
print(np.random.random())

0.5488135039273248
0.7151893663724195
0.6027633760716439


In [None]:
np.random.seed(0)

In [None]:
_ = np.random.uniform()
rng_state = np.random.get_state(); rng_state

('MT19937',
 array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
        3904844661,  676747479, 2085143622, 1056793272, 3812477442,
        2168787041,  275552121, 2696932952, 3432054210, 1657102335,
        3518946594,  962584079, 1051271004, 3806145045, 1414436097,
        2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
         696824676, 2399811678, 3992505346,  569184356, 2626558620,
         136797809, 4273176064,  296167901, 3430730584,  331909803,
        1908676996, 1950065095,  604298543, 3615988338, 1570232852,
        1028209748, 1511467721, 2411887154, 4210753555, 3096762720,
         423429618,  659966766, 2937509307, 2222847265,  378636552,
        1142109618, 2509241601, 1521729757,  888533219,  250885260,
        2455816244, 4046047811, 1947467789, 1395351953, 2388948566,
         934627940,  194642258, 1429256273, 2139959677, 1543740405,
        1569613451, 4061840539, 2075690423,  824532376,  844152077,
        3218002536,  897315311,  823

In [None]:
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state(); rng_state

('MT19937',
 array([1499117434, 2949980591, 2242547484, 1470907986,   68004624,
         613504879, 2170701638, 3606168244, 1313189820, 2904302179,
        3340054280, 2800779156, 3718152353, 1082918459, 1748036786,
        3125556887, 1246967947, 2050301915, 3440863170, 2306625137,
        2391836667, 1253663658, 2419038162, 3499839328, 3576356820,
        3828856986,  723946277, 1516277410, 1749873187, 2585175776,
        2103116091, 3761404950, 2177145536, 2190050649, 2604636580,
        1049507822, 3538272245, 2566586914, 3538170909, 4282737256,
        3260797503, 2387454175, 2226689230, 2256270485,  436199026,
        1447928333, 1300475185, 3910190296, 2621047601, 2432253395,
        3548512997, 3038311477, 3870448599, 4184179771,  331186464,
        1513235983, 1123184249, 1412176674,  974731669, 1184859182,
        3903198916, 1010728009, 1157972564, 1456817460, 4280740152,
        3287444695, 3162962129, 2065442163,  702491398, 2129714181,
        1271816637, 1310830189, 1626

In [None]:
_ = np.random.uniform()
rng_state = np.random.get_state(); rng_state

('MT19937',
 array([1499117434, 2949980591, 2242547484, 1470907986,   68004624,
         613504879, 2170701638, 3606168244, 1313189820, 2904302179,
        3340054280, 2800779156, 3718152353, 1082918459, 1748036786,
        3125556887, 1246967947, 2050301915, 3440863170, 2306625137,
        2391836667, 1253663658, 2419038162, 3499839328, 3576356820,
        3828856986,  723946277, 1516277410, 1749873187, 2585175776,
        2103116091, 3761404950, 2177145536, 2190050649, 2604636580,
        1049507822, 3538272245, 2566586914, 3538170909, 4282737256,
        3260797503, 2387454175, 2226689230, 2256270485,  436199026,
        1447928333, 1300475185, 3910190296, 2621047601, 2432253395,
        3548512997, 3038311477, 3870448599, 4184179771,  331186464,
        1513235983, 1123184249, 1412176674,  974731669, 1184859182,
        3903198916, 1010728009, 1157972564, 1456817460, 4280740152,
        3287444695, 3162962129, 2065442163,  702491398, 2129714181,
        1271816637, 1310830189, 1626

### JAX's PRNG

JAX's PRNG is a modern, splittable, counter-based generator (Threefry). Instead of mutating a global state, it forks (splits) the PRNG key into new keys, which makes parallel random number generation possible.

In [None]:
from jax import random
key = random.PRNGKey(0)
key

Array([0, 0], dtype=uint32)

In [None]:
print(key)
print(random.normal(key, shape=(1,)))
print(random.normal(key, shape=(1,))) # Don't repeat

[0 0]
[-0.20584226]
[-0.20584226]


In [None]:
key, subkey = random.split(key); print(key)
print(random.normal(key, shape=(1,)))

[4146024105  967050713]
[0.14389051]


### Control flow

In [None]:
def f(x):
  """Piecewise function: 3*x**2 when x < 3, otherwise -4*x."""
  # Guard clause: handle the "not below 3" case first.
  if not (x < 3):
    return -4 * x
  return 3 * x ** 2

print(grad(f)(2.))
print(grad(f)(4.))

12.0
-4.0


#### `jit` constraints are still confusing

In [None]:
@jit
def f(x):
  """Double `x` three times, printing it before each doubling."""
  for _ in range(3):
    # Under jit this print runs while tracing, so it shows a tracer
    # object rather than a concrete number.
    print(x)
    x = 2 * x
  return x

print(f(3)) # I thought jit functions shouldn't change the variables bc it's functional?

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
24


In [None]:
@jit
def g(x):
  """Sum the entries of `x` along its first axis, starting from 0.0."""
  # x.shape is known while tracing, so this Python-level loop over
  # indices is allowed under jit.
  length = x.shape[0]
  return sum((x[i] for i in range(length)), 0.)

print(g(jnp.array([1.,2.,3.])))

6.0


In [None]:
@jit
def f(x):
  """Piecewise function whose branch depends on the runtime value of `x`.

  The comparison is made on a traced value, so calling this jitted
  version fails when Python tries to convert it to a bool.
  """
  is_small = x < 3
  if not is_small:
    return -4 * x
  return 3 * x ** 2

f(2)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at <ipython-input-25-504f2c15c277>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [None]:
def f(x):
  """Return 3*x**2 when x < 3, otherwise -4*x."""
  return 3 * x ** 2 if x < 3 else -4 * x

f = jit(f, static_argnums=(0,))
f(2.) # what's the point of using jit in this case even though it works

Array(12., dtype=float32, weak_type=True)

In [None]:
def f(x,n):
  """Add up the first `n` entries of `x`, starting from 0."""
  running_total = 0
  for position in range(n):
    running_total = running_total + x[position]
  return running_total

f = jit(f, static_argnums=(1,))
f(jnp.array([2.,3.,4.]), 2)

Array(5., dtype=float32)

### `lax`


In [None]:
def cond(pred, true_fn, false_fn, operand):
  """Apply `true_fn` to `operand` if `pred` is truthy, else apply `false_fn`."""
  chosen_branch = true_fn if pred else false_fn
  return chosen_branch(operand)


In [None]:
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x:x+1, lambda x:x-1, operand)

Array([1.], dtype=float32)

In [None]:
lax.cond(False, lambda x:x+1, lambda x:x-1, operand)

Array([-1.], dtype=float32)

* `lax.select`
* `lax.switch`
* `lax.scan`
* `jnp.where`
* `jnp.piecewise`
* `jnp.select`

### `while_loop`

In [None]:
def while_loop(cond_fn, body_fn, init_val):
  """Apply `body_fn` repeatedly while `cond_fn` holds; return the final value."""
  state = init_val
  while True:
    # Stop as soon as the condition no longer holds.
    if not cond_fn(state):
      return state
    state = body_fn(state)


In [None]:
init_val = 0

In [None]:
cond_fn = lambda x:x<10
body_fn = lambda x:x+1
lax.while_loop(cond_fn, body_fn, init_val)

Array(10, dtype=int32, weak_type=True)

### `fori_loop`

In [None]:
def fori_loop(start, stop, body_fn, init_val):
  """Pure-Python reference for the semantics of `lax.fori_loop`.

  Args:
    start: first loop index (inclusive).
    stop: last loop index (exclusive).
    body_fn: callable taking `(i, val)` and returning the next `val`.
    init_val: value carried into the first iteration.

  Returns:
    The carried value after the last iteration, or `init_val` when the
    index range is empty.
  """
  val = init_val
  for i in range(start, stop):
    # The loop index is passed to the body together with the carried value.
    val = body_fn(i, val)
  return val

In [None]:
init_val = 0
start = 0
stop = 10
body_fn = lambda i, x: x+i
lax.fori_loop(start, stop, body_fn, init_val)

Array(45, dtype=int32, weak_type=True)

### JAX transforms like jit, vmap, grad, etc. require all output arrays and intermediate arrays to have static shape!

In [None]:
def nansum(x):
  """Sum the entries of `x` that are not NaN, using boolean-mask indexing."""
  # Boolean indexing yields an array whose length depends on the data in x.
  return x[~jnp.isnan(x)].sum()

In [None]:
x = jnp.array([1.,2.,3.,4.,jnp.nan])
print(nansum(x))

10.0


In [None]:
jit(nansum)(x)

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [None]:
def nansum_2(x):
  """Sum the entries of `x`, counting NaNs as 0; the shape of `x` is kept."""
  nan_positions = jnp.isnan(x)
  # Replace NaNs with 0 instead of dropping them, so no shape changes.
  zero_filled = jnp.where(nan_positions, 0, x)
  return zero_filled.sum()

In [None]:
jit(nansum_2)(x)

Array(10., dtype=float32)

## More of JIT

* JAX transformation
* JAX compilation

JAX transformation and compilation are designed to work **only** on **functionally pure** Python functions.


JIT is used for a Python function to be executed efficiently in XLA.

At first it's pretty confusing to know what counts as a side effect for `@jit`. I find it helpful to understand how `jit` abstracts the values in a function through `jaxpr`. For example, the `Traced<ShapedArray(int32[], weak_type=True)>` output seen earlier shows how `jit` abstracts an argument: it keeps only the type and shape, and by specializing on those (much like a type constraint) it can run optimized code in XLA.

This level of abstraction in tracing has trade-offs. JAX traces this way so that, once a function is compiled, it can be reused for other values with the same type and dimensions without tracing and compiling it again, because the compiled XLA code is cached. (Is this the right understanding?)

Unlike `jax.jit`, which is stricter, `jax.grad`'s constraints are more relaxed.


So, this works somehow.

```
def dim_x_2(x):
  if x.ndim == 2:
    return x*2
  else:
    return x

jit_log2 = jit(dim_x_2)
jit_log2(jnp.array([[1,2]])) # Array([[2, 4]], dtype=int32)

```
But the following returns an error. (It's an example from the JAX documentation)

```
def f(x):
  if x > 0: # x == 0 doesn't work either
    return x
  else:
    return 2 * x

f_jit = jax.jit(f)
f_jit(10)
```

So, a condition that only checks the dimension works, but a condition that depends on the actual values in the argument does not.

```
def type_float(x):
  if type(x) == float:
    return x
  else:
    return x*2

jit_type_float = jit(type_float)
jit_type_float(1)
```

This worked as well. So in summary, `jit` works well with conditions that only check the dimension or the type of the input; however, if a condition has to check the actual values in the input, it raises an error.

In [1]:
def print_side_effect(x):
  """Print a fixed message, then return `x` unchanged."""
  print("Executing function")
  return x

In [2]:
print(f"Just with Python: {print_side_effect(4)}")

Executing function
Just with Python: 4


In [3]:
from jax import jit
print(f"First JAX call: {jit(print_side_effect)(4.)}")

Executing function
First JAX call: 4.0


In [4]:
# I think this is because the Python body (including the print) only runs while the function is traced on the first call; the second call reuses the cached compiled result, so nothing is printed
print(f"Second JAX call: {jit(print_side_effect)(5.)}")

Second JAX call: 5.0


In [5]:
# But if we put a different type (or a shape)
print(f'Third JAX call with an array: {jit(print_side_effect)([6.])}')

Executing function
Third JAX call with an array: [Array(6., dtype=float32, weak_type=True)]


https://github.com/google/jax/issues/196

In [7]:
g = 0
def global_var_side_effect(x):
  """Print the module-level `g` and return `x + g`.

  `g` is read from the enclosing module scope rather than passed in, so
  the result depends on state outside the function's arguments.
  """
  print(f'g: {g}')
  return x + g

In [10]:
global_var_side_effect(4)

g: 0


4

In [11]:
g = 10
global_var_side_effect(5)

g: 10


15

In [12]:
g = 0
print(f'First JAX call: {jit(global_var_side_effect)(4.)}')

First JAX call: 4.0


In [13]:
print(jit(global_var_side_effect)(5.))

5.0


In [14]:
g = 10
# doesn't reflect a new g's value
print(f'Second JAX call: {jit(global_var_side_effect)(5.)}')

Second JAX call: 5.0


In [17]:
# but if we put another type, the function will be re-run
import jax.numpy as jnp
print(f'Third JAX call: {jit(global_var_side_effect)(jnp.array([10.]))}')

g: 10
Third JAX call: [20.]


In [21]:
g = 0
def global_kw_impure(x):
  """Store `x` in the module-level `g`, then return `x`."""
  global g  # the assignment below rebinds the module-level name
  g = x
  return x

In [19]:
global_kw_impure(5)

5

In [20]:
print(g)

5


In [22]:
print(f'First call: {jit(global_kw_impure)(5)}')

First call: 5


In [23]:
print(g)

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [26]:
print(f'Second call: {jit(global_kw_impure)(6)}')

Second call: 6


In [27]:
print(g)

Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [36]:
def pure_internal_state(x):
  """Add `x` to an 'even' or 'odd' bucket for i in 0..9; return the 'even' total."""
  # State that is created and modified only inside the function still
  # leaves the function functionally pure.
  totals = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    totals[bucket] += x
  return totals['even']

In [38]:
print(jit(pure_internal_state)(5.))

25.0


It's not recommended to use iterators in any JAX function you want to `jit` or put in any control-flow primitive.

In [48]:
import jax.lax as lax
from jax import make_jaxpr

array = jnp.arange(10); array

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [42]:
print(lax.fori_loop(0,10,lambda i, x : x+array[i], 0))# looping over i, x is initialized to 0; so essentially adding numbers in array array

45


In [44]:
iterator = iter(range(10))
iterator # python object that retrieves next value

<range_iterator at 0x7a82ef6b1500>

In [46]:
lax.fori_loop(0,10, lambda i, x: x+next(iterator), 0)

Array(0, dtype=int32, weak_type=True)

In [58]:
def func11(arr, extra):
  """Scan over `arr` paired with an all-ones array of the same shape.

  Each step adds `element * one + extra` to the running carry, and emits
  the carry as it was *before* that update. Returns what `lax.scan`
  returns: the final carry and the stacked per-step outputs.
  """
  def step(running, pair):
    left, right = pair
    updated = running + left*right + extra
    return (updated, running)

  unit_weights = jnp.ones(arr.shape)
  return lax.scan(step, 0., (arr, unit_weights))

In [51]:
make_jaxpr(func11)(jnp.arange(16), 5.)

{ lambda ; a:i32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[] e:f32[16] = scan[
      jaxpr={ lambda ; f:f32[] g:f32[] h:i32[] i:f32[]. let
          j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
          k:f32[] = mul j i
          l:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          m:f32[] = add l k
          n:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          o:f32[] = add m n
        in (o, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

In [59]:
func11(jnp.array([1.,2.]), 5.)

(Array(13., dtype=float32), Array([0., 6.], dtype=float32, weak_type=True))

In [61]:
def body(carry, aelems):
  """One scan step: return (carry + product of the pair + 5., previous carry)."""
  left, right = aelems
  next_carry = carry + left*right + 5.
  return (next_carry, carry)

In [81]:
jnp.array([[1,2]]).ndim

2

In [86]:
type(1.) == float

True

In [87]:
def type_float(x):
  """Return `x` as-is when its type is exactly `float`, otherwise `x*2`."""
  is_plain_float = type(x) == float
  return x if is_plain_float else x*2

jit_type_float = jit(type_float)
jit_type_float(1)

Array(2, dtype=int32, weak_type=True)

In [82]:
def dim_x_2(x):
  """Double `x` when it is 2-dimensional; otherwise return it unchanged."""
  # Only x.ndim is inspected here, not the values stored in x.
  return x*2 if x.ndim == 2 else x

jit_log2 = jit(dim_x_2)
jit_log2(jnp.array([[1,2]]))

Array([[2, 4]], dtype=int32)

In [68]:
from jax import jit
jit_log2 = jit(log2_if_rank_2)

In [69]:
jit_log2(jnp.array([1,2,3]))

Array([1, 2, 3], dtype=int32)

In [74]:
make_jaxpr(jit_log2)(jnp.array([1,2,3]))

{ lambda ; a:i32[3]. let
    b:i32[3] = pjit[
      name=log2_if_rank_2
      jaxpr={ lambda ; c:i32[3]. let  in (c,) }
    ] a
  in (b,) }

In [88]:
# you can jit only the computationally intensive part

@jit
def loop_i(prev_i):
  """Jitted single step: return the counter incremented by one."""
  return prev_i + 1

def until_n(x, n):
  """Advance a counter from 0 with `loop_i` until it reaches `n`; return `x` plus it."""
  counter = 0
  # The loop condition is evaluated in plain Python; only the step is jitted.
  while counter < n:
    counter = loop_i(counter)
  return x + counter

until_n(10,20)

Array(30, dtype=int32, weak_type=True)

In [90]:
def f(x):
  """Return `x` when it is positive, otherwise `2 * x`."""
  # Guard clause for the non-positive case.
  if not x > 0:
    return 2 * x
  return x

# jit(f)(10) doesn't work because it has to check the values in x

jit(f, static_argnums=0)(10)

Array(10, dtype=int32, weak_type=True)

In [91]:
def g(x,n):
  """Count up from 0 in steps of 1 until reaching `n`, then return `x` plus the count."""
  steps = 0
  while steps < n:
    steps += 1
  return x + steps

# jit(g)(10,20) returns an error as well because it has to check the value of n
jit(g, static_argnames=['n'])(10,20)

Array(30, dtype=int32, weak_type=True)

In [92]:
# we can use decorators

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_decorated(x,n):
  """Jitted counter loop with `n` static: return `x` plus the number of steps taken."""
  steps = 0
  # n is a static argument, so this comparison uses a concrete Python value.
  while steps < n:
    steps += 1
  return x + steps

g_decorated(10,20)

Array(30, dtype=int32, weak_type=True)

## Caching

It's important to understand the caching behavior of `jit`

Suppose we define a jit function `jitted_g = jax.jit(g)`. When I first invoke `jitted_g`, it gets compiled, and the resulting XLA code gets cached. Subsequent calls of `jitted_g` reuse the compiled code. This is how jit makes up for the upfront compilation cost.

If we specify `static_argnums`, the cached code is reused only for the same values of the static arguments. If any of those values change, recompilation occurs.

## Manual Vectorization

In [93]:
# convolution of 1d arrays

x = jnp.arange(5); print(x)
w = jnp.array([2.,3.,4.])


[0 1 2 3 4]


In [95]:
def convolve(x,w):
  """Dot each 3-wide window of `x` (centered on interior positions) with `w`."""
  windows = [x[center-1:center+2] for center in range(1, len(x)-1)]
  return jnp.array([jnp.dot(window, w) for window in windows])

convolve(x,w)

Array([11., 20., 29.], dtype=float32)

In [101]:
# suppose we want to run this to a batch of `x`s and `w`s

xs = jnp.stack([x,x]); print(xs)
ws = jnp.stack([w,w]); print(ws)

[[0 1 2 3 4]
 [0 1 2 3 4]]
[[2. 3. 4.]
 [2. 3. 4.]]


In [97]:
# naive solution - not very efficient

def manually_batched_convolve(xs,ws):
  """Call `convolve` on each (xs[i], ws[i]) pair and stack the results."""
  batch_size = xs.shape[0]
  per_example = [convolve(xs[row], ws[row]) for row in range(batch_size)]
  return jnp.stack(per_example)

manually_batched_convolve(xs,ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [104]:
jnp.sum(xs[:,0:3]*ws, axis=1)

Array([11., 11.], dtype=float32)

In [105]:
def manually_vectorized_convolve(xs, ws):
  """Convolve every row of `xs` with the matching row of `ws`, one output column at a time.

  For each interior position, the 3-wide window of all rows is multiplied
  by `ws` and summed over axis 1. The list of per-position results is
  printed, then stacked along axis 1.
  """
  width = xs.shape[-1]
  columns = [
      jnp.sum(xs[:,center-1:center+2]*ws, axis=1)
      for center in range(1, width-1)
  ]
  print(columns)
  return jnp.stack(columns, axis=1)

In [106]:
manually_vectorized_convolve(xs,ws)

[Array([11., 11.], dtype=float32), Array([20., 20.], dtype=float32), Array([29., 29.], dtype=float32)]


Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

## Automatic Vectorization

In [107]:
def convolve(x,w):
  """Slide a 3-wide window across `x` and dot each window with `w`."""
  results = []
  # `start` is the left edge of each window, i.e. one less than its center.
  for start in range(len(x)-2):
    results.append(jnp.dot(x[start:start+3], w))
  return jnp.array(results)

jax.vmap(convolve)(xs,ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [112]:
# vmaps in_axes and out_axes args are the location of the batch dimension in inputs and for outputs

xst = jnp.transpose(xs); print(xst) # in this case, the batch location is 1
wst = jnp.transpose(ws); print(wst)

jax.vmap(convolve, in_axes=1, out_axes=1)(xst,wst)

[[0 0]
 [1 1]
 [2 2]
 [3 3]
 [4 4]]
[[2. 2.]
 [3. 3.]
 [4. 4.]]


Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

In [113]:
jax.vmap(convolve, in_axes=1, out_axes=0)(xst, wst) # if we want batch dimension comes back to 0

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [114]:
# what if we have x batched, but want to use not batched weights

jax.vmap(convolve, in_axes=[0, None])(xs,w)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [115]:
# we can jit vmap
jax.jit(jax.vmap(convolve, in_axes=1, out_axes=0))(xst, wst)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)