# JAX 101


Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/what-is-jax
"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logo.png" alt="Open in ML Nuggets"></a>

JAX is an open source Python package for numerical computation in with accelearators and XLA.

## Installing JAX



pip install jax

### Why use JAX?

- faster than NumPy
- consumes less memory and is convenient to use
- Other data science packages are built on top of it, for example Flax, Haiku etc

## Setting up TPUs on Google Colab

Ensure you change the run time to TPUs

In [1]:
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.devices()

In [2]:
import jax.numpy as jnp 
import numpy as np

# Data Types in JAX

In [11]:
x = jnp.float32(1.25844) 

In [12]:
x

DeviceArray(1.25844, dtype=float32)

In [13]:
type(x)

jaxlib.xla_extension.DeviceArray

In [14]:
x = jnp.int32(45.25844) 

In [15]:
x

DeviceArray(45, dtype=int32)

## Ways to Create JAX NumPy Arrays

A NumPy array is a multidimensional array-like data structure

### np.arange()

In [16]:
jnp.arange(10)

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [17]:
jnp.arange(0,10)

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [18]:
jnp.arange(0,10,2)

DeviceArray([0, 2, 4, 6, 8], dtype=int32)

### Covert Python List to a NumPy Array

In [19]:
scores = [50,60,70,30,25,70]

In [None]:
scores_array = jnp.array(scores)

In [None]:
scores_array

DeviceArray([50, 60, 70, 30, 25, 70], dtype=int32)

In [None]:
type(scores_array)

jaxlib.xla_extension.DeviceArray

In [None]:
scores_array.ndim # the dimension of the array

1

In [None]:
scores_array.size # the number of items in the array

6

In [None]:
scores_array.dtype # the type of data in the array

dtype('int32')

In [None]:
jnp.unique(scores_array) # print unique items from the array

DeviceArray([25, 30, 50, 60, 70], dtype=int32)

In [None]:
scores_array

DeviceArray([50, 60, 70, 30, 25, 70], dtype=int32)

In [None]:
jnp.flip(scores_array) # reverse the array

DeviceArray([70, 25, 30, 70, 60, 50], dtype=int32)

In [None]:
jnp.sort(scores_array)

DeviceArray([25, 30, 50, 60, 70, 70], dtype=int32)

In [None]:
scores_array

DeviceArray([50, 60, 70, 30, 25, 70], dtype=int32)

In [None]:
jnp.clip(scores_array, 20,59)

DeviceArray([50, 59, 59, 30, 25, 59], dtype=int32)

# Part Two

### Joining Two Arrays

In [None]:
array_two = jnp.array([90, 26, 37, 77, 65, 55])

In [None]:
scores_array

DeviceArray([50, 60, 70, 30, 25, 70], dtype=int32)

In [None]:
jnp.concatenate((scores_array, array_two))

DeviceArray([50, 60, 70, 30, 25, 70, 90, 26, 37, 77, 65, 55], dtype=int32)

### jnp.zeros()

In [None]:
jnp.zeros(5)

DeviceArray([0., 0., 0., 0., 0.], dtype=float32)

### jnp.ones()

In [None]:
jnp.ones(5)

DeviceArray([1., 1., 1., 1., 1.], dtype=float32)

In [None]:
jnp.eye(5) # Return a 2-D array with ones on the diagonal
           # and zeros elsewhere.

DeviceArray([[1., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 1.]], dtype=float32)

In [None]:
jnp.identity(5)

DeviceArray([[1., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 1.]], dtype=float32)

### jnp.linspace()

In [None]:
jnp.linspace(10,50,5) # Return evenly spaced numbers over a 
                      # specified interval.
                    # start,stop, num=5,

DeviceArray([10., 20., 30., 40., 50.], dtype=float32)

In [None]:
jnp.linspace(10,15,5)

DeviceArray([10.  , 11.25, 12.5 , 13.75, 15.  ], dtype=float32)

## Generating random numbers with JAX

In [20]:
print(np.random.random())
print(np.random.random())
print(np.random.random())

0.5537111303546517
0.7468391798711446
0.526741353307414


In [21]:
seed = 98
key = jax.random.PRNGKey(seed)

In [22]:
key

DeviceArray([ 0, 98], dtype=uint32)

In [23]:
jax.random.uniform(key)

DeviceArray(0.3756802, dtype=float32)

In [24]:
jax.random.uniform(key)

DeviceArray(0.3756802, dtype=float32)

In [25]:
key, subkey = jax.random.split(key)

In [27]:
subkey

DeviceArray([3614062411, 3294896607], dtype=uint32)

In [26]:
jax.random.uniform(subkey)

DeviceArray(0.95996785, dtype=float32)

In [None]:
jax.random.uniform(subkey)

DeviceArray(0.95996785, dtype=float32)

# Part 3

# Checking Documentation

In [None]:
help(jnp.linspace)

Help on function linspace in module jax._src.numpy.lax_numpy:

linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis: int = 0)
    Return evenly spaced numbers over a specified interval.
    
    LAX-backend implementation of :func:`numpy.linspace`.
    
    *Original docstring below.*
    
    Returns `num` evenly spaced samples, calculated over the
    interval [`start`, `stop`].
    
    The endpoint of the interval can optionally be excluded.
    
    .. versionchanged:: 1.16.0
        Non-scalar `start` and `stop` are now supported.
    
    .. versionchanged:: 1.20.0
        Values are rounded towards ``-inf`` instead of ``0`` when an
        integer ``dtype`` is specified. The old behavior can
        still be obtained with ``np.linspace(start, stop, num).astype(int)``
    
    Parameters
    ----------
    start : array_like
        The starting value of the sequence.
    stop : array_like
        The end value of the sequence, unless `endpoint` is set to

In [None]:
jnp.linspace?

[0;31mSignature:[0m
[0mnp[0m[0;34m.[0m[0mlinspace[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mstart[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstop[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum[0m[0;34m=[0m[0;36m50[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mendpoint[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mretstep[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maxis[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Return evenly spaced numbers over a specified interval.

LAX-backend implementation of :func:`numpy.linspace`.

*Original docstring below.*

Returns `num` evenly spaced samples, calculated over the
interval [`start`, `stop`].

The endpoint of the interval can optionally be excluded.

.. versionchanged:: 1.16.0
    Non-sc

JAX NumPy Operations

In [30]:
matrix = jnp.arange(17,33)

In [31]:
matrix = matrix.reshape(4,4)

In [32]:
matrix

DeviceArray([[17, 18, 19, 20],
             [21, 22, 23, 24],
             [25, 26, 27, 28],
             [29, 30, 31, 32]], dtype=int32)

In [49]:
try:
  jnp.sum([1, 2, 3])
except TypeError as e:
  print(f"TypeError: {e}")

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [33]:
matrix.shape

(4, 4)

In [34]:
matrix.ndim

2

In [35]:
jnp.max(matrix)

DeviceArray(32, dtype=int32)

In [36]:
matrix

DeviceArray([[17, 18, 19, 20],
             [21, 22, 23, 24],
             [25, 26, 27, 28],
             [29, 30, 31, 32]], dtype=int32)

In [37]:
jnp.argmax(matrix)

DeviceArray(15, dtype=int32)

In [38]:
jnp.min(matrix)

DeviceArray(17, dtype=int32)

In [39]:
jnp.argmin(matrix)

DeviceArray(0, dtype=int32)

In [40]:
jnp.sum(matrix)

DeviceArray(392, dtype=int32)

In [41]:
jnp.sqrt(matrix)

DeviceArray([[4.1231055, 4.2426405, 4.358899 , 4.472136 ],
             [4.582576 , 4.690416 , 4.7958317, 4.8989797],
             [5.       , 5.0990195, 5.196152 , 5.2915025],
             [5.3851647, 5.477226 , 5.5677643, 5.656854 ]], dtype=float32)

In [42]:
matrix.transpose()

DeviceArray([[17, 21, 25, 29],
             [18, 22, 26, 30],
             [19, 23, 27, 31],
             [20, 24, 28, 32]], dtype=int32)

In [43]:
matrix.flatten()

DeviceArray([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
             32], dtype=int32)

In [44]:
matrix.ravel()

DeviceArray([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
             32], dtype=int32)

In [45]:
matrix2 = jnp.arange(1,17).reshape(4,4)

In [46]:
matrix + matrix2

DeviceArray([[18, 20, 22, 24],
             [26, 28, 30, 32],
             [34, 36, 38, 40],
             [42, 44, 46, 48]], dtype=int32)

In [53]:
matrix = np.arange(17,33)
matrix = matrix.reshape(4,4)
matrix2 = np.arange(1,17).reshape(4,4)

In [54]:
%timeit matrix * matrix2

The slowest run took 54.28 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 5: 560 ns per loop


In [59]:
%timeit jnp.dot(jnp.arange(17,33).reshape(4,4), jnp.arange(1,17).reshape(4,4)).block_until_ready() 

1000 loops, best of 5: 462 µs per loop


In [None]:
matrix /  matrix2

DeviceArray([[17.       ,  9.       ,  6.3333335,  5.       ],
             [ 4.2      ,  3.6666667,  3.2857144,  3.       ],
             [ 2.7777777,  2.6      ,  2.4545455,  2.3333333],
             [ 2.2307692,  2.142857 ,  2.0666666,  2.       ]],            dtype=float32)

In [None]:
matrix %  matrix2

DeviceArray([[0, 0, 1, 0],
             [1, 4, 2, 0],
             [7, 6, 5, 4],
             [3, 2, 1, 0]], dtype=int32)

## Device put

In [2]:
from jax import device_put
import numpy as np 
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)

In [3]:
x

DeviceArray([[-0.1144228 , -0.24054267, -0.766892  , ...,  1.5308956 ,
               0.92768997, -1.2955136 ],
             [-1.2012097 , -1.3346976 ,  1.0732094 , ...,  1.6548555 ,
              -0.11896823, -0.03097145],
             [-0.42850187,  0.48592672, -0.02518239, ..., -0.26078698,
              -0.5199622 ,  0.9433329 ],
             ...,
             [-1.2751442 ,  0.15932202, -0.18655974, ...,  1.8270756 ,
               1.4053702 ,  1.3269509 ],
             [ 0.90323037, -0.165994  , -0.04873463, ...,  0.3966682 ,
               0.9030303 ,  0.7333918 ],
             [ 0.5417945 , -0.8683455 , -2.1000135 , ...,  0.483317  ,
               0.13317995,  1.0398728 ]], dtype=float32)

# Indexing & Broadcasting in Numpy

In [None]:
matrix = jnp.arange(1,17)

In [None]:
matrix

DeviceArray([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
             16], dtype=int32)

In [None]:
matrix[0]

DeviceArray(1, dtype=int32)

## Out-of-Bounds Indexing

In [None]:
# Out-of-Bounds Indexing
matrix[20]

DeviceArray(16, dtype=int32)

In [None]:
matrix[2]

DeviceArray(3, dtype=int32)

In [None]:
matrix[2:6]

DeviceArray([3, 4, 5, 6], dtype=int32)

In [None]:
matrix[12:]

DeviceArray([13, 14, 15, 16], dtype=int32)

### Indexing Two Dimensional Array

In [None]:
matrix = matrix.reshape(4,4)

In [None]:
matrix[0] 

DeviceArray([1, 2, 3, 4], dtype=int32)

In [None]:
matrix[0:2]

DeviceArray([[1, 2, 3, 4],
             [5, 6, 7, 8]], dtype=int32)

In [None]:
matrix[1:3,1:3] # [startrow:endrow, startcolumn:endcolumn]

DeviceArray([[ 6,  7],
             [10, 11]], dtype=int32)

In [None]:
matrix[2:4,1:2]

DeviceArray([[10],
             [14]], dtype=int32)

In [None]:
matrix[2:4,2:4]

DeviceArray([[11, 12],
             [15, 16]], dtype=int32)

In [None]:
matrix[2:,2:]

DeviceArray([[11, 12],
             [15, 16]], dtype=int32)

### Broadcasting in NumPy

In [74]:
scores = [50,60,70,30,25]

In [75]:
scores_array = jnp.array(scores)

In [76]:
scores_array

DeviceArray([50, 60, 70, 30, 25], dtype=int32)

In [77]:
scores_array[0:3]

DeviceArray([50, 60, 70], dtype=int32)

### JAX arrays are immutable

In [None]:
scores_array[0:3] = [20,40,90]

TypeError: ignored

In [78]:
new_scores_array = scores_array.at[0:3].set([20,40,90])

In [79]:
new_scores_array

DeviceArray([20, 40, 90, 30, 25], dtype=int32)

## Using jit() to speed up functions

In [60]:
def test_fn(sample_rate=3000,frequency=3):
  x = jnp.arange(sample_rate) 
  y = np.sin(2*jnp.pi*frequency * (frequency/sample_rate)) 
  return jnp.dot(x,y)

In [61]:
%timeit test_fn()

The slowest run took 704.38 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 76.1 µs per loop


In [62]:
test_fn_jit = jax.jit(test_fn)
%timeit test_fn_jit().block_until_ready()

The slowest run took 13996.83 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.54 µs per loop


## How JIT works

By default JAX executes operations one at a time, in sequence.

Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.

In [None]:
@jax.jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([3.7037132, 1.5194231, 5.3635793], dtype=float32)

In [None]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

DeviceArray([-3.1307886, -2.3813279, -0.6085395], dtype=float32)

In [None]:
from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c d
  in (e,) }

In [None]:
@jax.jit
def f(boolean, x):
  return -x if boolean else x

f(True, 1)

ConcretizationTypeError: ignored

In [None]:
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def f(boolean, x):
  return -x if boolean else x

f(True, 1)

DeviceArray(-1, dtype=int32, weak_type=True)

In [None]:
f(False, 1)

DeviceArray(1, dtype=int32, weak_type=True)

## Taking derivatives with grad()

In [7]:
@jax.jit
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357 0.04517666 0.01766271 0.00664806]


In [9]:
@jax.jit
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x))),(x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))

(DeviceArray([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
             0.00664806], dtype=float32), DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float32))


In [6]:
arcsinh = jax.grad(jax.numpy.arcsinh)
print(arcsinh(0.9))

0.7432942


## Auto-vectorization with vmap()

In [None]:
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))
def apply_matrix(v):
  return jnp.dot(mat, v)

In [None]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
The slowest run took 67.91 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 2.35 ms per loop


In [None]:
@jax.jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
The slowest run took 1684.56 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.6 µs per loop


In [None]:
@jax.jit
def vmap_batched_apply_matrix(v_batched):
  return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
The slowest run took 449.09 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 50.2 µs per loop


## Parallelization with pmap

In [3]:
x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)

DeviceArray([11., 20., 29.], dtype=float32)

In [4]:
n_devices = jax.local_device_count() 
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

xs

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39]])

In [5]:
jax.vmap(convolve)(xs, ws)

DeviceArray([[ 11.,  20.,  29.],
             [ 56.,  65.,  74.],
             [101., 110., 119.],
             [146., 155., 164.],
             [191., 200., 209.],
             [236., 245., 254.],
             [281., 290., 299.],
             [326., 335., 344.]], dtype=float32)

In [6]:
jax.pmap(convolve)(xs, ws)

ShardedDeviceArray([[ 11.,  20.,  29.],
                    [ 56.,  65.,  74.],
                    [101., 110., 119.],
                    [146., 155., 164.],
                    [191., 200., 209.],
                    [236., 245., 254.],
                    [281., 290., 299.],
                    [326., 335., 344.]], dtype=float32)

## Debugging NaNs

In [3]:
jnp.divide(0.0,0.0)



DeviceArray(nan, dtype=float32, weak_type=True)

In [4]:
from jax.config import config
config.update("jax_debug_nans", True)
jnp.divide(0.0,0.0)

Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid value encountered in the output of a jit-decorated function. Calling the de-optimized version.


FloatingPointError: ignored

## Double (64bit) precision

In [5]:
x = jnp.float64(1.25844) 

  lax_internal._check_user_dtype_supported(dtype, "array")


In [6]:
x

DeviceArray(1.25844, dtype=float32)

In [7]:
# set this config at the begining of the program
from jax.config import config
config.update("jax_enable_x64", True)
x = jnp.float64(1.25844) 
x

DeviceArray(1.25844, dtype=float64)

## PyTrees

In [8]:
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

In [9]:
# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = jax.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x7f280a01f6d0>]   has 3 leaves: [1, 'a', <object object at 0x7f280a01f6d0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
DeviceArray([1, 2, 3], dtype=int64)           has 1 leaves: [DeviceArray([1, 2, 3], dtype=int64)]


## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.