# TODO Recording:

- Before starting recording make sure you show enabling line numbers
- Also make sure that the current runtime is GPU

# What is JAX?

**JAX** is a framework well suited to machine learning research. A few points about JAX:
1. It looks just like `numpy`, but uses a compiler (XLA) to compile NumPy-style code and run it on accelerators (GPU/TPU).
2. For automatic differentiation, JAX uses an updated version of `Autograd`. It can automatically differentiate native Python and NumPy code.
3. JAX expresses numerical programs as compositions of function transformations, with certain constraints: JAX transformation and compilation are designed to work only on Python functions that are functionally pure. A function is pure if it always returns the same value when invoked with the same arguments and has no side effects, e.g. changing the state of a non-local variable.
4. In terms of syntax, JAX is very similar to NumPy, but there are subtle differences that you should be aware of.
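
To make the purity constraint in point 3 concrete, here is a minimal sketch. The names `call_count` and `impure_add` are hypothetical and exist only for this illustration. The side effect runs while JAX traces the function, so a later compiled call with the same input shape and dtype skips it:

```python
import jax
import jax.numpy as jnp

call_count = 0  # non-local state (hypothetical, for illustration only)

def impure_add(x):
    """Impure: mutates a global and bakes its current value into the result."""
    global call_count
    call_count += 1
    return x + call_count

jitted_add = jax.jit(impure_add)

x = jnp.array(1.0)
first = jitted_add(x)   # traced once: call_count becomes 1 and is captured as a constant
second = jitted_add(x)  # reuses the compiled code: the Python side effect does not run again

print(call_count)       # incremented only during tracing
print(first, second)    # both calls return the same value
```

Without `jax.jit`, the second call would increment the counter again and return a different value, which is the kind of hidden state JAX transformations are not designed to handle.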


In [1]:
!pip install jax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [1]:
import jax
import numpy as np
import jax.numpy as jnp

from jax import random

In [2]:
jax.__version__

'0.3.14'

# NumPy arrays vs JAX arrays

`jax.numpy` is very similar to `numpy` in terms of API. *Most of the operations* that you do in NumPy are also available in `jax.numpy` with similar semantics. Only a few operations are listed here to showcase this, but there are many more. Please check the [docs](https://jax.readthedocs.io/en/latest/jax.numpy.html) for the list of available functions.

**Note:** Not all NumPy functions are implemented in `jax.numpy` (yet).

Here we compare NumPy arrays with JAX arrays.

In [3]:
array_np = np.array([1, 2, 3, 4, 5])
array_np

array([1, 2, 3, 4, 5])

In [4]:
array_jax = jnp.array([1, 2, 3, 4, 5])
array_jax

DeviceArray([1, 2, 3, 4, 5], dtype=int32)

Data types can also be passed explicitly.



In [5]:
array_np = np.array([1, 2, 3, 4, 5], dtype = np.int32)
array_np

array([1, 2, 3, 4, 5], dtype=int32)

In [6]:
array_jax = jnp.array([1, 2, 3, 4, 5], dtype = jnp.int32)
array_jax

DeviceArray([1, 2, 3, 4, 5], dtype=int32)

# DeviceArray

`array_np` is an instance of **`ndarray`**, while `array_jax` is an instance of **`DeviceArray`**.

Here are the points that you should know about **`DeviceArray`**:
1. It is the core underlying JAX array object, similar to `ndarray` but with subtle differences.
2. Unlike `ndarray`, `DeviceArray` is backed by a memory buffer on a single device (CPU/GPU/TPU).
3. It is **device-agnostic**, i.e. JAX doesn't need to track the device on which the array is present, and can avoid data transfers.
4. Because it is device-agnostic, the same JAX code runs on CPU, GPU, or TPU with no code changes.
5. `DeviceArray` is **lazy**, i.e. the value of a JAX `DeviceArray` isn't immediately available on the host and is only pulled when requested.
6. Even though `DeviceArray` is lazy, you can still inspect the shape or type of a `DeviceArray` without waiting for the computation that produced it to complete. You can even pass it to another JAX computation.

These two properties, **lazy evaluation** and being **device-agnostic**, give **`DeviceArray`** a huge advantage and greatly improve the performance of computations performed with JAX.
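
As a small sketch of the device-agnostic point, the snippet below contains no device-specific code; it only asks JAX which backend it selected. The printed backend and device list depend on the runtime it is executed on.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX select? For example 'cpu', 'gpu' or 'tpu', depending on the runtime.
backend = jax.default_backend()
devices = jax.devices()
print('Backend:', backend)
print('Devices:', devices)

# The same array code runs unchanged on whichever backend is active.
x = jnp.ones((3, 3))
y = jnp.dot(x, x)

# Shape and dtype are metadata, so they can be read without pulling the values to the host.
print(y.shape, y.dtype)
```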

In [7]:
print('Type of Numpy Array :', type(array_np))
print('Type of JAX Array :', type(array_jax))

Type of Numpy Array : <class 'numpy.ndarray'>
Type of JAX Array : <class 'jaxlib.xla_extension.DeviceArray'>


In [8]:
array_np = np.arange(10)
array_jax = jnp.arange(10)

array_np, array_jax

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [9]:
array_np = np.linspace(1, 10, 10)
array_jax = jnp.linspace(1, 10, 10)

array_np, array_jax

(array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]),
 DeviceArray([ 1.       ,  2.       ,  3.       ,  4.       ,  5.       ,
               6.       ,  7.0000005,  8.       ,  9.       , 10.       ],            dtype=float32))

Sum and mean work the same way on both NumPy and JAX arrays.

In [10]:
print('Sum of numpy array elements :', array_np.sum())
print('Sum of JAX array elements :', array_jax.sum())

Sum of numpy array elements : 55.0
Sum of JAX array elements : 55.0


In [11]:
print('Mean of numpy array elements :', array_np.mean())
print('Mean of JAX array elements :', array_jax.mean())

Mean of numpy array elements : 5.5
Mean of JAX array elements : 5.5


A few other operations are compared as well, such as transpose and reshape; matrix multiplication is covered further below.

In [12]:
array_np = np.array([[1, 2, 3], [4, 5, 6]])
print('Numpy array:\n', array_np)

array_np_transposed = array_np.T
print('\nNumpy array transposed:\n', array_np_transposed)

Numpy array:
 [[1 2 3]
 [4 5 6]]

Numpy array transposed:
 [[1 4]
 [2 5]
 [3 6]]


In [13]:
array_jax = jnp.array([[1, 2, 3], [4, 5, 6]])
print('JAX array:\n', array_jax)

array_jax_transposed = array_jax.T
print('\nJAX array transposed:\n', array_jax_transposed )

JAX array:
 [[1 2 3]
 [4 5 6]]

JAX array transposed:
 [[1 4]
 [2 5]
 [3 6]]


In [14]:
print('Original shape of Numpy array:', array_np.shape)
print('Original shape of JAX array:', array_jax.shape)

array_np_reshaped = array_np.reshape(1, -1)
array_jax_reshaped = array_jax.reshape(1, -1)

print('\nNew shape of Numpy array:', array_np_reshaped.shape)
print('New shape of JAX array:', array_jax_reshaped.shape)

Original shape of Numpy array: (2, 3)
Original shape of JAX array: (2, 3)

New shape of Numpy array: (1, 6)
New shape of JAX array: (1, 6)


NumPy is deeply integrated with Python and can operate directly on Python lists.

In [15]:
np.sum([2, 3, 4, 6])

15

`jax.numpy` functions do not accept Python lists or tuples as inputs. Implicitly converting lists and tuples can silently degrade performance, especially when used with the JIT compiler, so JAX raises an error instead of converting them for you.

In [16]:
jnp.sum([2, 3, 4, 6])

TypeError: ignored
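
A minimal sketch of the workaround, assuming you simply want the same sum as the NumPy cell: convert the list explicitly with `jnp.array` before calling the `jax.numpy` function. The variable names here are made up for illustration.

```python
import jax.numpy as jnp

values = [2, 3, 4, 6]

# Passing the list directly raises a TypeError, as in the cell above.
try:
    jnp.sum(values)
    list_rejected = False
except TypeError:
    list_rejected = True

# Converting explicitly makes the list-to-array step visible, and the sum then works.
total = jnp.sum(jnp.array(values))
print(list_rejected, total)
```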

Now, let's take a look at some of the things that you can do in NumPy but not in `jax.numpy`, and vice versa.

# Immutability

JAX arrays are **immutable**, just like [**TensorFlow tensors**](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1). This means JAX arrays don't support item assignment the way `ndarray` does. Let's take an example!

In [16]:
array_np = np.arange(10, dtype = np.int32)
array_jax = jnp.arange(10, dtype = jnp.int32)

array_np, array_jax

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32),
 DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [17]:
array_np[4] = 22222

array_np

array([    0,     1,     2,     3, 22222,     5,     6,     7,     8,
           9], dtype=int32)

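Item assignment is not supported on JAX arrays. Use the `.at[...]` update syntax instead, which returns a new array and leaves the original unchanged.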

In [19]:
array_jax[4] = 22222

TypeError: ignored

In [18]:
array_jax_modified = array_jax.at[4].set(2222)

array_jax_modified

DeviceArray([   0,    1,    2,    3, 2222,    5,    6,    7,    8,    9], dtype=int32)

In [19]:
array_jax

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

**Why immutability?** JAX relies on **pure functions** to JIT compile code and get greatly improved performance.

JAX uses the immutability of arrays to optimize the computation graph of your numerical operations. To do this, JAX must be able to fuse operations, eliminate dependencies, and parallelize operations. Allowing item assignment or in-place updates makes this kind of graph optimization very difficult.

In [20]:
print(array_jax.at[5].add(10))

print(array_jax.at[5].mul(2))

[ 0  1  2  3  4 15  6  7  8  9]
[ 0  1  2  3  4 10  6  7  8  9]
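
The `.at[...]` updates also compose inside a JIT-compiled function. The sketch below uses a hypothetical helper, `bump_ends`, to chain two updates; the input array is left untouched, which is what keeps the function pure.

```python
import jax
import jax.numpy as jnp

@jax.jit
def bump_ends(x):
    # Each .at[...] call returns a new array. Under jit, the compiler may
    # perform the update in place when the old value is not reused.
    return x.at[0].add(1).at[-1].set(0)

original = jnp.arange(5)
updated = bump_ends(original)

print(original)  # unchanged input
print(updated)   # first element incremented, last element set to zero
```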


# Asynchronous dispatch

One of the biggest differences between `ndarrays` and `DeviceArrays` is in their execution and their availability. JAX uses asynchronous dispatch to hide Python overheads. Let's take an example to understand what it means.

In [21]:
array_np = np.random.normal(size = (10000, 10000)).astype(np.float32)
array_jax = jax.random.normal(jax.random.PRNGKey(0), (10000, 10000), dtype = jnp.float32) 

print('Shape of numpy array: ', array_np.shape)
print('Shape of JAX array: ', array_jax.shape)

Shape of numpy array:  (10000, 10000)
Shape of JAX array:  (10000, 10000)


In [22]:
%time np.matmul(array_np, array_np)

print('Completed NumPy operation')

CPU times: user 45.1 s, sys: 101 ms, total: 45.2 s
Wall time: 23 s
Completed NumPy operation


The first JAX call includes warm-up and takes a bit more time. Subsequent runs are faster than the first run.

In [24]:
%time jnp.matmul(array_jax, array_jax)

print('Completed JAX operation')

CPU times: user 1.51 ms, sys: 11 µs, total: 1.53 ms
Wall time: 1.42 ms
Completed JAX operation


The cells above run the same matrix multiplication on each array and time it.

Wow! It seems that the `DeviceArray` computation finished in no time. This is where you should remember the following:
1. Unlike the result of an `ndarray` operation, the result of a computation on a `DeviceArray` isn't available immediately. It is a **future** value that will become available on the accelerator.
2. You can retrieve the value of this computation by **printing** it or by converting it into a plain old NumPy `ndarray`.
3. The timing above for `DeviceArray` is the time taken to **dispatch** the work, not the time taken for the actual computation.
4. Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and the Python code does not actually need to inspect the output of a computation on the host, a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
5. To measure the true cost of any such operation:
     - Either convert it to a plain NumPy `ndarray` (not preferred)
     - Or use `block_until_ready()` to wait for the computation that produced it to complete (preferred way for benchmarking)

Let's apply these two methods to measure the correct computation time.

In [43]:
%time np.asarray(jnp.matmul(array_jax, array_jax))

CPU times: user 693 ms, sys: 0 ns, total: 693 ms
Wall time: 691 ms


array([[ -90.649605 ,   13.972294 ,  -95.644356 , ...,   23.926107 ,
         133.53967  ,   53.144093 ],
       [  44.89995  ,  -33.371754 ,   94.96752  , ..., -100.38466  ,
         -56.9396   , -217.22105  ],
       [ -45.115227 , -185.57512  , -189.55267  , ..., -213.17244  ,
          11.189074 ,   18.810091 ],
       ...,
       [  57.29252  ,   89.313675 ,   96.97262  , ...,  -26.702463 ,
          32.321266 ,  159.42375  ],
       [  92.361015 ,   29.874603 ,  -63.730263 , ...,   41.16846  ,
         154.73875  ,   85.76176  ],
       [   7.0634017,   81.06211  ,  145.52782  , ...,   73.70803  ,
         -43.59744  ,  -10.067663 ]], dtype=float32)

In [44]:
%time jnp.matmul(array_jax, array_jax).block_until_ready()

CPU times: user 4.59 ms, sys: 0 ns, total: 4.59 ms
Wall time: 593 ms


DeviceArray([[ -90.649605 ,   13.972294 ,  -95.644356 , ...,
                23.926107 ,  133.53967  ,   53.144093 ],
             [  44.89995  ,  -33.371754 ,   94.96752  , ...,
              -100.38466  ,  -56.9396   , -217.22105  ],
             [ -45.115227 , -185.57512  , -189.55267  , ...,
              -213.17244  ,   11.189074 ,   18.810091 ],
             ...,
             [  57.29252  ,   89.313675 ,   96.97262  , ...,
               -26.702463 ,   32.321266 ,  159.42375  ],
             [  92.361015 ,   29.874603 ,  -63.730263 , ...,
                41.16846  ,  154.73875  ,   85.76176  ],
             [   7.0634017,   81.06211  ,  145.52782  , ...,
                73.70803  ,  -43.59744  ,  -10.067663 ]], dtype=float32)
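
Outside a notebook, where the `%time` magic is not available, the same distinction can be sketched with `time.perf_counter`. This is only an illustration on a smaller matrix. The actual numbers depend on the hardware, and on a CPU-only runtime the two timings may be close to each other.

```python
import time
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000), dtype=jnp.float32)

# Warm-up call so one-time compilation and setup costs are not included in the timings.
jnp.matmul(x, x).block_until_ready()

# Dispatch only: the call can return before the result is ready.
start = time.perf_counter()
pending = jnp.matmul(x, x)
dispatch_seconds = time.perf_counter() - start

# Dispatch plus compute: block until the device has produced the result.
start = time.perf_counter()
result = jnp.matmul(x, x).block_until_ready()
total_seconds = time.perf_counter() - start

print(f'dispatch only:      {dispatch_seconds:.6f} s')
print(f'dispatch + compute: {total_seconds:.6f} s')
```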

In [25]:
%timeit -n5 np.dot(array_np, array_np.T)

5 loops, best of 5: 11.9 s per loop


In [26]:
%timeit -n5 jnp.dot(array_jax, array_jax.T)

The slowest run took 3547.42 times longer than the fastest. This could mean that an intermediate result is being cached.
5 loops, best of 5: 138 µs per loop


In [27]:
# Runs on the GPU with transfer overhead

%timeit -n5 jnp.dot(array_np, array_np.T)

5 loops, best of 5: 267 ms per loop


In [28]:
from jax import device_put

array_np_gpu = device_put(array_np)

%timeit -n5 jnp.dot(array_np_gpu, array_np_gpu.T)

The slowest run took 3403.05 times longer than the fastest. This could mean that an intermediate result is being cached.
5 loops, best of 5: 145 µs per loop
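
Note that the JAX `%timeit` cells above do not wait for the result, so the microsecond figures probably reflect dispatch time rather than compute time. Below is a minimal sketch of a blocking comparison between a host (NumPy) input and an input already placed on the device with `device_put`. The helper `timed` is hypothetical, and a smaller matrix is used so the sketch runs quickly.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def timed(fn, repeats=5):
    """Average wall-clock seconds per call, waiting for each result to be ready."""
    fn().block_until_ready()  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn().block_until_ready()
    return (time.perf_counter() - start) / repeats

host_array = np.random.normal(size=(1000, 1000)).astype(np.float32)
device_array = jax.device_put(host_array)  # copy to the default device once

# NumPy input: the data has to be transferred to the device on every call.
from_host = timed(lambda: jnp.dot(host_array, host_array.T))
# Device input: the data is already where the computation runs.
from_device = timed(lambda: jnp.dot(device_array, device_array.T))

print(f'NumPy input:  {from_host:.6f} s per call')
print(f'device input: {from_device:.6f} s per call')
```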


# Type promotion

This is another aspect to keep in mind. `dtype` promotion in JAX is less aggressive than in NumPy. A few things:
1. JAX always prefers the precision of the JAX value when promoting a Python scalar
2. JAX always prefers the floating-point or complex type when promoting an integer or boolean type against a floating-point or complex type
3. JAX uses floating-point promotion rules that are more suited to modern accelerator devices like GPUs/TPUs

Let's take an example to see these in action

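NumPy promotes `int16` to `int64` when a Python integer is added. This is the NumPy 1.x behavior recorded below; under the NEP 50 promotion rules adopted in NumPy 2, the result stays `int16`.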

In [46]:
np.int16(128).dtype

dtype('int16')

In [47]:
(np.int16(128) + 4).dtype

dtype('int64')


`jnp.int16(128) + 4` returns `int16` rather than promoting to `int64` as in NumPy.

In [48]:
(jnp.int16(128) + 4).dtype

dtype('int16')

Next, compare implicit casting when a Python float is added to an `int32` array: NumPy gives `float64`, while JAX gives `float32`.

In [49]:
array_int_np = np.random.randint(10, size = 10, dtype = np.int32)

array_int_np

array([8, 0, 5, 0, 8, 3, 6, 7, 8, 7], dtype=int32)

In [50]:
(array_int_np + 5.0).dtype

dtype('float64')

In [51]:
array_int_jax = jax.random.randint(jax.random.PRNGKey(0),
                            minval = 0,
                            maxval = 9,
                            shape = [10],
                            dtype = jnp.int32
                           )

array_int_jax

DeviceArray([8, 1, 8, 1, 2, 0, 3, 4, 1, 3], dtype=int32)

In [52]:
(array_int_jax + 5.0).dtype

dtype('float32')
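
The promotion rules can also be queried directly, without building arrays first. This is a small sketch using `np.promote_types` and `jnp.promote_types`; the last two lines illustrate the first rule from the list above (Python scalars take the precision of the JAX value).

```python
import numpy as np
import jax.numpy as jnp

# Array-with-array promotion: int32 combined with float32.
numpy_result = np.promote_types(np.int32, np.float32)
jax_result = jnp.promote_types(jnp.int32, jnp.float32)
print(numpy_result)  # NumPy widens to float64
print(jax_result)    # JAX keeps float32

# Python scalars are weakly typed in JAX, so the array dtype wins.
small_int = (jnp.arange(3, dtype=jnp.int8) + 1).dtype
half_float = (jnp.ones(3, dtype=jnp.float16) + 1.5).dtype
print(small_int, half_float)
```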