# JAX Intro
This notebook is a brief introduction to JAX. After working through it, you will have a basic idea of:
1. What JAX is
2. Why it's useful
3. How to use it at a basic level

In [None]:
import jax
import jax.numpy as jnp  # The numpy api implemented in JAX

## Why use JAX?
![](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)

In Google's words:
>JAX is Autograd and XLA, brought together for high-performance numerical computing, including large-scale machine learning research.

That's great if you know what Autograd and XLA are. Let's assume we don't.

### Autograd
[Autograd](https://github.com/hips/autograd) is an automatic differentiation library for pure Python and NumPy.
At a basic level, this means we can efficiently compute gradients of functions in our code.

### XLA
[XLA](https://github.com/openxla/xla) stands for Accelerated Linear Algebra. It is a compiler for machine learning frameworks that lets us compile our code for CPUs, GPUs, TPUs, NPUs, and whatever acronyms we come up with next for new accelerators.

### Takeaway
JAX gives us fast, compiled code that can run on accelerators and efficiently compute gradients. This is _extremely_ useful not only for machine learning and AI, but also for inference, simulation, etc. in the sciences.

## Some JAX Basics
One of the nicest things about JAX is that it implements the [NumPy API](https://jax.readthedocs.io/en/latest/jax.numpy.html).
We imported this earlier as `jnp`.

In [None]:
ones = jnp.ones((3, 3))
ones

In [None]:
twos = jnp.ones((3, 3)) * 2
twos

In [None]:
threes = ones + twos
threes

In [None]:
threes.devices()  # CUDA, CPU, TPU, etc.

As you can see, we can manipulate arrays with JAX (mostly) as we expect.
Here is an example of where JAX behaves a bit differently.
Let's say we want to change one of the array elements:

In [None]:
threes[0, 1] += 1

**We get an error!**

Why? It's because JAX requires us to write *functional* code, in the sense of functional programming.
The in-place update above is an example of a *side effect*: it mutates existing data instead of producing a new value, which functional programming does not allow.
Instead, we need an operation that leaves the existing array untouched and returns an updated copy.
Luckily, JAX provides these:

In [None]:
threes.at[0, 1].add(1)

In [None]:
# But this is still the same
threes

To be **functional**, we must assign the returned array to a name (here we rebind `threes`):

In [None]:
threes = threes.at[0, 1].add(1)
threes

This pattern will repeat throughout JAX codebases.
Most of JAX's objects (like these arrays) are what we would call **immutable**.
Think of tuples in pure Python.
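
The `.at[...]` property offers more than `add`. The cell below is a minimal sketch of a few other update methods (`set`, `multiply`) and of chaining them. The array name `original` and the chosen indices are assumptions made for illustration only.

```python
import jax.numpy as jnp

original = jnp.full((3, 3), 3.0)

# Each .at[...] method returns a new array; `original` is never changed.
set_result = original.at[0, 1].set(10.0)      # overwrite one element
mul_result = original.at[:, 0].multiply(2.0)  # scale the first column
# Updates can be chained because each call returns a fresh array.
chained = original.at[0, 0].set(0.0).at[2, 2].add(1.0)

print(original)
print(set_result)
print(mul_result)
print(chained)
```

Printing `original` last or first makes no difference: it still holds all threes, just like a tuple that was never rebuilt.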

## Taking Gradients in JAX
We will use `jax.grad` to compute the gradient of a function written with JAX.

In [None]:
def sum_of_squares(x):
    return jnp.sum(x**2)

In [None]:
sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))

We can also get back the value and the gradient at the same time.

In [None]:
sum_of_squares_val_dx = jax.value_and_grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_val_dx(x))
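
As a sanity check, we can compare the result of `jax.grad` against the gradient we would derive by hand, which for a sum of squares is `2 * x`. The sketch below also shows the `argnums` parameter, which selects the argument to differentiate with respect to. The function `weighted_sum_of_squares` and the weights `w` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    return jnp.sum(x**2)


x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

# The analytic gradient of sum(x**2) is 2 * x.
grad_matches = jnp.allclose(jax.grad(sum_of_squares)(x), 2 * x)


# Hypothetical two-argument function, used only to illustrate `argnums`.
def weighted_sum_of_squares(x, w):
    return jnp.sum(w * x**2)


w = jnp.asarray([0.5, 1.0, 1.5, 2.0])
# argnums=1 differentiates with respect to the second argument, w.
# By hand, d/dw sum(w * x**2) is x**2.
grad_wrt_w = jax.grad(weighted_sum_of_squares, argnums=1)(x, w)

print(grad_matches)
print(grad_wrt_w)
```

By default `jax.grad` differentiates with respect to the first argument, which is why the earlier cells did not need `argnums`.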

## JIT (Just in Time) Compilation
We can also compile JAX functions using XLA.
To do this, JAX traces the function into an intermediate representation called a *jaxpr* (JAX exPRession).
The jaxpr is then lowered and passed to the XLA compiler.

Let's see an example of a jaxpr:

In [None]:
jax.make_jaxpr(sum_of_squares)(x)

Notice how the above is broken down into a series of primitive operations.

Now, let's look at compiling the above function.

In [None]:
sum_of_squares_jit = jax.jit(sum_of_squares)
z = x.repeat(1_000_000, axis=0)
# Warm up with z itself so compilation for this shape is not part of the timing
sum_of_squares_jit(z).block_until_ready()

In [None]:
%%timeit
sum_of_squares_jit(z).block_until_ready()

In [None]:
%%timeit
sum_of_squares(z).block_until_ready()  # wait for the async result so the comparison is fair

A very silly example, but JAX does give us a speedup. This would be more evident if executed on a GPU.
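
It helps to know when compilation actually happens. The sketch below relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so appending to a list reveals each trace. The names `trace_log` and `sum_of_squares_traced` are hypothetical helpers introduced for this illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time Python runs the body


@jax.jit
def sum_of_squares_traced(x):
    # This Python side effect runs only while JAX traces the function,
    # not on calls that reuse an already compiled version.
    trace_log.append(x.shape)
    return jnp.sum(x**2)


sum_of_squares_traced(jnp.ones(4))  # first call: trace and compile
sum_of_squares_traced(jnp.ones(4))  # same shape and dtype: reuse the cache
sum_of_squares_traced(jnp.ones(8))  # new shape: trace and compile again

print(trace_log)
```

Three calls produce only two log entries, one per distinct input shape. This is why a warm-up call should use the same shape and dtype as the input we intend to time.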

## Automatic Vectorization
You may be familiar with the idea of *vectorized computations*, a form of *single instruction, multiple data* parallel processing.
See the following example of elementwise array multiplication:

In [None]:
a, b = jnp.arange(5), jnp.arange(5)
jax.make_jaxpr(lambda a, b: a * b)(
    a, b
)  # This is a single primitive! It is done in parallel

### What's going on here?
Our CPU (or GPU, or TPU) has vector instruction sets.
It knows how to execute a single instruction (multiplication) on multiple data (the array elements) in parallel.
The details vary by processor and get complicated very quickly.

![](./imgs/vectorization.png)

(Image credit https://datascience.blog.wzb.eu/)

### What if we don't use a built-in vectorized function?
Most NumPy array math and linear algebra routines are vectorized (as are the JAX counterparts), but what happens if you want to vectorize your own custom code?
JAX has a utility to automatically do this: `jax.vmap`.

Below we have a function that computes a sliding-window weighted sum of an input array (a simple 1D convolution).

In [None]:
def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])

convolve(x, w)

Let's assume we actually have a batch of arrays and weights. We want to vectorize our function across this batch dimension.

In [None]:
xs = jnp.stack([x, 2 * x])
ws = jnp.stack([w, w])
xs, ws

It's not too difficult to manually update the code to be vectorized across the batch dimension.

In [None]:
def manually_vectorized_convolve(xs, ws):
    output = []
    for i in range(1, xs.shape[-1] - 1):
        output.append(jnp.sum(xs[:, i - 1 : i + 2] * ws, axis=1))
    return jnp.stack(output, axis=1)


manually_vectorized_convolve(xs, ws)

These array gymnastics get tiring, and it's not always this easy. We can use JAX to make this trivial:

In [None]:
vmap_convolve = jax.vmap(convolve)

vmap_convolve(xs, ws)
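
By default `jax.vmap` maps over the leading axis of every argument. If only some arguments are batched, the `in_axes` parameter says which ones. The sketch below assumes a single weight vector `w` shared by every row of `xs`, and checks the result against a plain Python loop. The names `shared_w_convolve`, `batched`, and `looped` are introduced here for illustration.

```python
import jax
import jax.numpy as jnp


def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.stack([x, 2 * x])

# in_axes=(0, None): map over axis 0 of xs, reuse the same w for every row.
shared_w_convolve = jax.vmap(convolve, in_axes=(0, None))
batched = shared_w_convolve(xs, w)

# Reference result from a plain Python loop over the batch.
looped = jnp.stack([convolve(row, w) for row in xs])

print(batched)
print(jnp.allclose(batched, looped))
```

This avoids stacking copies of `w` just to give it a batch dimension.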

We can also combine JIT with vmap to compile our vectorized function.

In [None]:
jitted_vmap_convolve = jax.jit(vmap_convolve)

jitted_vmap_convolve(xs, ws)

## The Broader JAX Ecosystem
- [Flax](https://flax.readthedocs.io/en/latest/): neural networks with JAX
- [Optax](https://optax.readthedocs.io/en/latest/): gradient processing and optimization library
- [Equinox](https://docs.kidger.site/equinox/): neural nets, JIT utilities, and more
- [NumPyro](https://github.com/pyro-ppl/numpyro): probabilistic programming built on JAX
- [BlackJAX](https://blackjax-devs.github.io/blackjax/): collection of samplers for Bayesian inference
- [jax.scipy](https://jax.readthedocs.io/en/latest/jax.scipy.html): implementation of popular SciPy routines

- [AwesomeJAX](https://github.com/n2cholas/awesome-jax): curated list of popular JAX packages