# An Introduction to JAX

JAX is a framework for fast linear algebra operations and automatic differentiation.

Loosely speaking, JAX is like NumPy with the addition of

* automatic differentiation
* automated GPU support
* a just-in-time compiler

JAX is often used for machine learning and AI, since it can scale to big data operations on GPUs and automatically differentiate loss functions for gradient descent.
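To make the link between automatic differentiation and gradient descent concrete, here is a minimal sketch. The quadratic loss, its target vector, the step size, and the iteration count are all made up for illustration; only `jax.grad` and `jax.jit` are JAX functions.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Hypothetical quadratic loss, minimized at theta = (1, -2)
    target = jnp.array([1.0, -2.0])
    return jnp.sum((theta - target) ** 2)

# jax.grad builds a function that returns the gradient of loss;
# jax.jit compiles that gradient function
grad_loss = jax.jit(jax.grad(loss))

theta = jnp.zeros(2)   # initial guess
step_size = 0.1        # illustrative learning rate
for _ in range(100):
    theta = theta - step_size * grad_loss(theta)

print(theta)
```

Each update rebinds `theta` to a new array rather than modifying it in place, which is the pattern used throughout JAX code.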

Here is a short history of JAX:

* 2015: Google open-sources part of its AI infrastructure called TensorFlow.
* 2016: The popularity of TensorFlow grows rapidly.
* 2017: Facebook open-sources PyTorch beta, an alternative AI framework (developer-friendly, more Pythonic).
* 2018: Facebook launches a full production-ready version of PyTorch.
* 2019: PyTorch explodes in popularity (adopted by Uber, Airbnb, Tesla, etc.)
* 2020: Google launches JAX as an open-source framework.
* 2021: Google starts to shift away from TPUs to Nvidia GPUs, and JAX development accelerates.
* 2022: JAX popularity begins to take off.

## Installation

JAX can be installed with or without GPU support.

* Follow [the install guide](https://github.com/google/jax)

Note that JAX is pre-installed with GPU support on [Google Colab](https://colab.research.google.com/).

(Colab Pro offers better GPUs.)
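After installing, one quick way to see which hardware JAX has picked up is to list the available devices. This is a small sketch; the output depends entirely on the machine it runs on.

```python
import jax

# List the devices visible to JAX (CPU only, or one or more accelerators)
print(jax.devices())

# Name of the backend JAX uses by default on this machine
print(jax.default_backend())
```

If only CPU devices are listed on a machine that has a GPU, the most likely explanation is that the CPU-only build of JAX was installed.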

## JAX as a NumPy Replacement

One way to use JAX is as a direct NumPy replacement.  Let's look at the similarities and differences.

### Similarities

The following import is standard, replacing `import numpy as np`:

In [115]:
import jax.numpy as jnp

Now we can use `jnp` in place of `np` for the usual array operations:

In [116]:
a = jnp.asarray((1.0, 3.2, -1.5))

In [117]:
print(a)

[ 1.   3.2 -1.5]


In [118]:
print(jnp.sum(a))

2.6999998


In [119]:
print(jnp.mean(a))

0.9


In [120]:
print(jnp.dot(a, a))

13.490001


However, the array object `a` is not a NumPy array:

In [121]:
type(a)

jaxlib.xla_extension.DeviceArray

Likewise, scalar-valued reductions of arrays are of type `DeviceArray`:

In [122]:
type(jnp.sum(a))

jaxlib.xla_extension.DeviceArray

In [123]:
jnp.sum(a)

DeviceArray(2.6999998, dtype=float32)

The term `Device` refers to the hardware accelerator (typically a GPU or TPU) on which the array is stored, although JAX falls back to the CPU if no accelerator is available.

(In the terminology of GPUs, the "host" is the machine that launches GPU operations, while the "device" is the GPU itself.)
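The host/device distinction can be illustrated by moving data explicitly in each direction. The sketch below uses `jax.device_put` and `jax.device_get`; the variable names are illustrative, and on a CPU-only machine the "device" is simply the CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_host = np.linspace(0, 1, 3)     # ordinary NumPy array in host memory
x_dev = jax.device_put(x_host)    # copy the data to the default device
y_dev = jnp.sin(x_dev)            # computed on the device
y_host = jax.device_get(y_dev)    # copy the result back as a NumPy array

print(type(x_dev))
print(type(y_host))
```

In everyday use these transfers usually happen implicitly, for example when a NumPy array is passed to a `jnp` function.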

JAX uses 32-bit floats by default, whether or not a GPU is present.  This is standard for GPU computing.
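When 64-bit precision is needed, it can be switched on through a configuration flag. The sketch below assumes a fresh session, with the flag set before any arrays are created.

```python
import jax

# Enable 64-bit floats and integers; best done at the start of a session
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.linspace(0, 1, 3)
print(x.dtype)
```

Without the flag, the same call produces a `float32` array, as in the outputs shown elsewhere in this lecture.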

Operations on higher-dimensional arrays are also similar to NumPy:

In [124]:
A = jnp.ones((2, 2))
B = jnp.identity(2)
A @ B

DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)

In [125]:
from jax.numpy import linalg

In [126]:
linalg.solve(B, A)

DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)

In [127]:
linalg.eigh(B)  # Computes eigenvalues and eigenvectors

(DeviceArray([0.99999994, 0.99999994], dtype=float32),
 DeviceArray([[1., 0.],
              [0., 1.]], dtype=float32))
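Because the identity matrix makes the examples above somewhat trivial, here is a sketch with a less special system. The matrix `M` and vector `b` are made up for illustration; the check at the end allows for 32-bit rounding error.

```python
import jax.numpy as jnp

# Hypothetical symmetric matrix and right-hand side
M = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

x = jnp.linalg.solve(M, b)              # solve M x = b
eigvals, eigvecs = jnp.linalg.eigh(M)   # M is symmetric, so eigh applies

# Verify the solution up to floating-point tolerance
print(jnp.allclose(M @ x, b, atol=1e-5))
```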

### Differences

As a NumPy replacement, the biggest difference is that arrays are treated as **immutable**.  For example, with NumPy we can write:

In [128]:
import numpy as np
a = np.linspace(0, 1, 3)
a

array([0. , 0.5, 1. ])

In [102]:
a[0] = 1
a

array([1. , 0.5, 1. ])

In JAX this fails:

In [103]:
a = jnp.linspace(0, 1, 3)
a

DeviceArray([0. , 0.5, 1. ], dtype=float32)

In [104]:
a[0] = 1

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

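In line with immutability, JAX does not support in-place operations.  With NumPy, `a.sort()` sorts the array in place, whereas with JAX it returns a new sorted array and leaves `a` unchanged: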

In [108]:
a = np.array((2, 1))
a.sort()
a

array([1, 2])

In [109]:
a = jnp.array((2, 1))
a.sort()
a

DeviceArray([2, 1], dtype=int32)
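To actually obtain the sorted values in JAX, the result has to be bound to a name. A minimal sketch (the name `b` is illustrative):

```python
import jax.numpy as jnp

a = jnp.array((2, 1))
b = jnp.sort(a)   # returns a new, sorted array

print(a)  # the original array is left as it was
print(b)  # the sorted copy
```

Writing `a = jnp.sort(a)` would rebind the name `a` to the sorted array, which is the usual substitute for an in-place sort.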

The designers of JAX chose to make arrays immutable because JAX uses a functional programming style.  More on this below.  

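Note that, while in-place mutation is not supported, the `at` property provides a functional alternative: it returns an updated copy of the array, as in

In [110]:
a = jnp.linspace(0, 1, 3)
id(a)

139862896932400

In [111]:
a

DeviceArray([0. , 0.5, 1. ], dtype=float32)

In [112]:
a.at[0].set(1)

DeviceArray([1. , 0.5, 1. ], dtype=float32)

The array displayed above is a new object, and `a` itself still holds its original values.  The identity of `a` is unchanged simply because the name `a` was never rebound:

In [114]:
id(a)

139862896932400

To keep the update, rebind the name, as in `a = a.at[0].set(1)`.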
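The following sketch checks directly that `.at[...]` updates return new arrays and leave the original untouched. The names `b` and `c` are illustrative.

```python
import jax.numpy as jnp

a = jnp.linspace(0, 1, 3)

b = a.at[0].set(1)     # copy of a with element 0 replaced by 1
c = a.at[1].add(10)    # copy of a with 10 added to element 1

print(a)        # original values are unchanged
print(b)
print(c)
print(b is a)   # the update produced a different object
```

Inside compiled functions, JAX can often carry out such updates without making a physical copy, so the functional style need not be slow.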

## Random Numbers



Random number generation also works differently in JAX than in NumPy.  In JAX, the state of the random number generator typically needs to be controlled explicitly.

In [43]:
import jax.random as random

First we produce a key, which seeds the random number generator.

In [63]:
key = random.PRNGKey(1)

In [64]:
type(key)

jaxlib.xla_extension.DeviceArray

Now we can use the key to generate some random numbers:

In [65]:
x = random.normal(key, (3, 3))
x

DeviceArray([[ 0.690805  , -0.48744103, -1.155789  ],
             [ 0.12108463,  1.2010182 , -0.5078766 ],
             [ 0.91568655,  1.70968   , -0.36749417]], dtype=float32)

If we use the same key again, we initialize at the same seed, so the random numbers are the same:

In [66]:
random.normal(key, (3, 3))

DeviceArray([[ 0.690805  , -0.48744103, -1.155789  ],
             [ 0.12108463,  1.2010182 , -0.5078766 ],
             [ 0.91568655,  1.70968   , -0.36749417]], dtype=float32)

To produce a (quasi-) independent draw, best practice is to "split" the existing key:

In [67]:
key, subkey = random.split(key)

In [68]:
random.normal(key, (3, 3))

DeviceArray([[-0.64377284,  0.7696183 , -0.29809612],
             [ 0.47858787,  1.3699535 ,  1.2741846 ],
             [-0.4408543 , -0.2564722 ,  1.4826155 ]], dtype=float32)

In [69]:
random.normal(subkey, (3, 3))

DeviceArray([[-0.58221203,  0.5907554 , -0.22852424],
             [ 1.8015778 ,  0.22681691, -0.07008386],
             [ 0.9982648 , -0.1906528 ,  0.22489986]], dtype=float32)

The function below produces `k` (quasi-) independent random `n x n` matrices using this procedure.

In [79]:
def gen_random_matrices(key, n, k):
    matrices = []
    for _ in range(k):
        key, subkey = random.split(key)
        matrices.append(random.uniform(subkey, (n, n)))
    return matrices


In [80]:
matrices = gen_random_matrices(key, 2, 2)
for A in matrices:
    print(A)

[[0.5101472  0.26307964]
 [0.037184   0.5259018 ]]
[[0.22010922 0.736096  ]
 [0.82404494 0.28857923]]
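As an alternative to the loop, all `k` subkeys can be created in one call to `random.split` and the draws mapped over them with `jax.vmap`. This is a sketch rather than a drop-in replacement: the function name `gen_random_matrices_vmap` is hypothetical, and because the keys are derived differently, the numbers will not match those produced by the loop version.

```python
import jax
import jax.random as random

def gen_random_matrices_vmap(key, n, k):
    # Split the key into k subkeys in one call
    keys = random.split(key, k)
    # Draw one n x n uniform matrix per subkey
    draw = lambda subkey: random.uniform(subkey, (n, n))
    return jax.vmap(draw)(keys)

key = random.PRNGKey(1)
matrices = gen_random_matrices_vmap(key, 2, 3)
print(matrices.shape)   # one stacked array with k as the leading axis
```

The result is a single array of shape `(k, n, n)` rather than a Python list, which is often more convenient for further vectorized operations.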


One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays.  Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in

In [82]:
random.normal(key, (5, ))

DeviceArray([ 0.9087967 , -0.040249  ,  0.17204419, -1.6576358 ,
              0.353745  ], dtype=float32)

## Functional Programming

According to the JAX documentation, JAX transformations such as `jit` and `grad` are designed to work only on functionally pure Python functions.
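To see why purity matters, consider the following sketch, in which a compiled function reads a global variable. The names `scale`, `impure_f`, and `pure_f` are made up for illustration.

```python
import jax

scale = 1.0

@jax.jit
def impure_f(x):
    # Impure: the result depends on a global variable
    return scale * x

@jax.jit
def pure_f(x, scale):
    # Pure: the result depends only on the arguments
    return scale * x

first = impure_f(2.0)    # compiled with scale = 1.0 baked in
scale = 10.0
second = impure_f(2.0)   # reuses the compiled version, so the new value is ignored
third = pure_f(2.0, scale)

print(first, second, third)
```

The impure version silently keeps using the value of `scale` that was in effect when the function was first compiled, while the pure version responds to the change because `scale` is passed in as an argument.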