# What is JAX?

<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="640">

### JAX is a deep learning framework written by Google

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
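
To make "composable" concrete, here is a minimal sketch that stacks three of the transformations named above on one function. The toy function `loss` and the arrays `w` and `xs` are invented for illustration and are not part of any JAX API.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar-valued function of a weight vector w and one input x
    return jnp.sum(jnp.tanh(w * x))

# Differentiate w.r.t. w, vectorize over a batch of x, then JIT-compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))  # batch of 4 inputs
grads = per_example_grads(w, xs)
print(grads.shape)  # (4, 3): one gradient per example
```

Each transformation takes a function and returns a function, which is why they can be nested in this way.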

[JAX on GitHub](https://github.com/google/jax)

[JAX reference documentation](https://jax.readthedocs.io/en/latest/)


XLA (Accelerated Linear Algebra) is a compiler-based linear algebra execution engine. It is the backend that powers machine learning frameworks such as TensorFlow and JAX on a variety of devices, including CPUs, GPUs, and TPUs.
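
If you want to check which of those device types your own installation targets, a short sketch like the following should do it. The printed values depend entirely on your machine.

```python
import jax

# Name of the backend JAX is using in this process, for example "cpu"
backend = jax.default_backend()
# The devices JAX can see on that backend
devices = jax.devices()
print(backend, devices)
```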

You've probably heard of TensorFlow and PyTorch, and maybe even MXNet, but there is a new kid on the block of machine learning frameworks: Google's JAX.

Over the last two years, JAX has been taking deep learning research by storm, facilitating the implementation of Google's Vision Transformer (ViT) and powering research at DeepMind.

So what is so exciting about the new JAX framework?

# JAX at Large
Boiled down, JAX is Python's NumPy with automatic differentiation, optimized to run on accelerators such as GPUs and TPUs. The near-seamless translation between writing NumPy and writing JAX has made JAX popular with machine learning practitioners.
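
As a rough illustration of how close the two APIs are, the sketch below writes the same expression once with `numpy` and once with `jax.numpy`. The variable names are made up for this example. It also shows one difference worth knowing about: JAX arrays are immutable, so in-place assignment becomes a functional update.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

# The same expression written against both APIs
y_np = np.sum(np.sin(x_np) ** 2)
y_jnp = jnp.sum(jnp.sin(x_jnp) ** 2)
print(y_np, y_jnp)  # the two values agree up to float32 rounding

# JAX arrays cannot be modified in place; .at[...].set(...) returns a new array
x_updated = x_jnp.at[0].set(10.0)
```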

JAX offers four main function transformations that make it efficient to use when executing deep learning workloads.

## JAX's Four Function Transformations
### 1. grad
- Automatically differentiates a function, which is what backpropagation needs. You can apply `grad` repeatedly to get derivatives of any order.

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```
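
The "any derivative order" point can be sketched by applying `grad` to its own output. The names `d1`, `d2`, and `d3` are just labels chosen here, and the analytic formulas are printed alongside for comparison.

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

d1 = grad(tanh)  # first derivative
d2 = grad(d1)    # second derivative
d3 = grad(d2)    # third derivative

x = 1.0
t = jnp.tanh(x)
# Analytic results for comparison: tanh' = 1 - t^2, tanh'' = -2 t (1 - t^2)
print(d1(x), 1.0 - t ** 2)
print(d2(x), -2.0 * t * (1.0 - t ** 2))
print(d3(x))
```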


### 2. jit
- Compiles your function with XLA so that its operations run efficiently, for example by fusing element-wise operations. It can also be used as a function decorator.

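```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # IPython magic: run this in a notebook
%timeit -n10 -r3 slow_f(x)
```

```
The slowest run took 239.21 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 32.9 µs per loop
The slowest run took 4.95 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 4.71 ms per loop
```

Treat these numbers with some care. The large spread on `fast_f` is consistent with the first call paying for compilation, and because JAX dispatches work asynchronously, calling `.block_until_ready()` on the result gives a more trustworthy measurement.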
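
Outside a notebook, a timing loop might look like the sketch below. It uses the decorator form of `jit`, keeps the compiling first call out of the timed region, and waits for each result with `block_until_ready()`. The array size and the loop count of 10 are arbitrary choices made for this example.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit  # decorator form of jit
def fast_f(x):
    return x * x + x * 2.0

x = jnp.ones((1000, 1000))

# The first call traces and compiles; keep it out of the timed region
fast_f(x).block_until_ready()

start = time.perf_counter()
for _ in range(10):
    # block_until_ready waits for the asynchronous computation to finish
    fast_f(x).block_until_ready()
elapsed = (time.perf_counter() - start) / 10
print(f"{elapsed * 1e3:.3f} ms per call")
```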


### 3. vmap
- Maps a function across array axes. This means you can write the function for a single example and let `vmap` handle the batch dimension, instead of tracking it by hand.

```python
from jax import vmap  # predict, params, input_batch are defined elsewhere
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```
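
The line above leaves `predict`, `params`, and `input_batch` undefined. Here is one self-contained way to fill them in, assuming a hypothetical single-layer model. The model, shapes, and key names are illustrative only.

```python
import jax.numpy as jnp
from jax import random, vmap

def predict(params, x):
    # Hypothetical single-example model: one linear layer plus tanh
    W, b = params
    return jnp.tanh(W @ x + b)

key_w, key_x = random.split(random.PRNGKey(0))
params = (random.normal(key_w, (3, 5)), jnp.zeros(3))  # W: (3, 5), b: (3,)
input_batch = random.normal(key_x, (8, 5))             # 8 examples of size 5

# in_axes=(None, 0): share params, map over axis 0 of input_batch
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions.shape)  # (8, 3)
```

Note that `predict` itself never mentions the batch dimension. `vmap` adds it.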

### 4. pmap
- Maps a function across multiple devices, such as the GPUs in a multi-GPU machine, and runs the copies in parallel.

```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU (requires 8 devices)
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
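
The example above assumes eight devices. If you only have one, a variant like this sketch should still run, because it sizes the leading axis from `jax.local_device_count()`. The smaller matrix shape and the name `n_dev` are choices made here for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random, pmap

n_dev = jax.local_device_count()  # 1 on a plain CPU machine

# One PRNG key per device, then one small random matrix per device
keys = random.split(random.PRNGKey(0), n_dev)
mats = pmap(lambda key: random.normal(key, (50, 60)))(keys)

# Per-device matmul and mean, following the same steps as the example above
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
means = pmap(jnp.mean)(result)
print(result.shape, means.shape)  # (n_dev, 50, 50) and (n_dev,)
```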

<img src="https://theaisummer.com/static/65961ba55109646b3aed515c7dba67cb/ee604/jax-tensorflow-pytorch.png" width="740">

# JAX vs PyTorch
The nearest machine learning framework to JAX is PyTorch. That is because both have their roots in striving to be as "NumPy-esque" as possible.

JAX's lower-level, function-oriented design makes it preferable for certain research tasks.

That said, PyTorch offers a much greater breadth of libraries and utilities, pre-trained and pre-written network definitions, a data loader, and portability to deployment destinations.
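
To give a feel for what that lower-level, function-oriented style looks like in practice, here is a minimal sketch of a hand-written gradient descent loop for a linear model. The names `loss_fn` and `sgd_step`, the learning rate, and the synthetic data are all assumptions made for this example, not a recommended training setup.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error of a linear model y = w * x + b
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # Parameters go in and new parameters come out; nothing is mutated
    grads = jax.grad(loss_fn)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 1.0  # synthetic data with w = 3, b = 1
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params = sgd_step(params, x, y)
print(params)  # should approach w = 3, b = 1
```

The model state lives in a plain tuple that is passed explicitly, rather than inside a module object.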

# JAX vs TensorFlow
JAX and TensorFlow were both written by Google. From my initial experimentation, JAX seems much easier to develop in and more intuitive.

That said, JAX lacks the extensive infrastructure that TensorFlow has built over the years: open source projects, pre-trained models, tutorials, higher-level abstractions (via Keras), and portability to deployment destinations.

# What does JAX lack?
- A data loader: you'll need to implement your own or borrow one from TensorFlow or PyTorch.
- Higher-level model abstractions
- Deployment portability
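
For small in-memory datasets, "implement your own" data loader can be as simple as the generator sketched below. The function name `batch_iterator` and the toy arrays are hypothetical, and a real pipeline would likely need prefetching and more careful shuffling.

```python
import numpy as np

def batch_iterator(features, labels, batch_size, seed=0):
    """Yield shuffled (features, labels) mini-batches for one epoch."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(features))
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], labels[idx]

# Toy dataset: 10 examples with 4 features each
features = np.arange(40, dtype=np.float32).reshape(10, 4)
labels = np.arange(10)

batches = list(batch_iterator(features, labels, batch_size=4))
print([xb.shape for xb, _ in batches])  # [(4, 4), (4, 4), (2, 4)]
```

The batches are plain NumPy arrays, which JAX functions accept directly.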

# When should we use JAX?
JAX is a new machine learning framework that has been gaining popularity in machine learning research.

If you're operating in the research realm, JAX is a good option for your project.

If you're actively developing an application, PyTorch and TensorFlow will likely move your project along faster. And of course, in computer vision there is always a tradeoff to weigh between building and buying your tooling.

Thanks for reading about JAX! Happy learning, and of course, happy inferencing!