In [6]:
import jax.numpy as jnp
import numpy as np

# Special transform functions (we'll understand what these are very soon!)
from jax import grad, jit, vmap, pmap
from jax import make_jaxpr
from jax import random
from jax import device_put
import matplotlib.pyplot as plt

In [7]:
# Seed and PRNG key for JAX's explicit, key-based random API.
seed = 0
key = random.PRNGKey(seed)

# Fact 4: JAX is AI accelerator agnostic. Same code runs everywhere!

# Side length of the square float32 matrices used in the benchmark below.
size = 3000

# Data is automagically pushed to the AI accelerator! (DeviceArray structure;
# newer JAX releases call this array type jax.Array)
# No more need for ".to(device)" (PyTorch syntax)
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
# NumPy counterpart: the shape goes in via `size=` and the draw comes from NumPy's
# global RNG instead of `key`. That RNG is not seeded in this cell, so x_np holds
# different values than x_jnp and changes between runs (this does not matter for
# the timings below, which only depend on shape and dtype).
x_np = np.random.normal(size=(size, size)).astype(np.float32)  # some diff in API exists!

# Time the same product, X @ X.T, with the input living in different places.
# JAX dispatches work asynchronously, so .block_until_ready() is needed for
# %timeit to measure the finished computation rather than just the dispatch.
%timeit jnp.dot(x_jnp, x_jnp.T).block_until_ready()  # 1) on GPU - fast
%timeit np.dot(x_np, x_np.T)  # 2) on CPU - slow (NumPy only works with CPUs)
# Case 3 passes NumPy arrays straight to jnp.dot, so each timed call also pays
# for moving the data to the accelerator.
%timeit jnp.dot(x_np, x_np.T).block_until_ready()  # 3) on GPU with transfer overhead

# Transfer once up front so the timed call below no longer includes the copy.
x_np_device = device_put(x_np)  # push NumPy explicitly to GPU
%timeit jnp.dot(x_np_device, x_np_device.T).block_until_ready()  # same as 1)

# Note1: I'm using GPU as a synonym for AI accelerator.
# In reality, especially in Colab, this can also be a TPU, etc.


10 loops, best of 5: 25.9 ms per loop
1 loop, best of 5: 424 ms per loop
10 loops, best of 5: 92.9 ms per loop
10 loops, best of 5: 24.4 ms per loop


- It is device-agnostic, i.e., your code doesn't need to track which device an array lives on, and data that is already on the device can be reused without additional transfers.

- Because it is device-agnostic, the same JAX code can run on CPU, GPU, or TPU with no code changes.