# JAX Basics

In [1]:
import jax
from jax import numpy as jnp
import numpy as np


# cool feature no. 1

x = np.array([1.0, 2.0])[None].T
y = jnp.array([1.0, 2.0])[None].T

# numpy arrays are transformed to JAX tensors automatically at operation execution (no to_tensor mumbo-jumbo)
print("Dot product:", x.T@y)



Dot product: [[5.]]


In [4]:

M = np.random.uniform(-0.1, 0.1, (500, 500))

def numpy_function(M):
    return M @ M.T


def function(M):
    return M @ M.T


@jax.jit
def jit_function(M):
    return  M @ M.T

In [12]:
print("Numpy function:")
%timeit -n 50 numpy_function(M)
print("JAX function:")
%timeit -n 50 function(M)
print("JAX jit function:")
%timeit -n 50 jit_function(M)

Numpy function:
The slowest run took 5.10 times longer than the fastest. This could mean that an intermediate result is being cached.
6.11 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)
JAX function:
11.2 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)
JAX jit function:
4.06 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [None]:
# so how do we calculate gradients?



In [36]:
# dataset for linear regression

x = np.random.uniform(-1,1, (100, 10))
theta = np.random.uniform(-1,1, (10,2))
y = x @ theta
theta_ = np.random.uniform(-1,1, (10,2))


In [37]:
y.shape

(100, 2)

In [38]:
@jax.jit
def predict(x, theta):
    return x@theta

def mse(x, theta, y):
    y_ = predict(x, theta)
    return ((y-y_)**2).mean()



grad_func = jax.grad(mse, argnums=[1]) # returns a function which takes the same arguments as the wrapped one


# returns tuple of gradients with respect to arguments of function
grad, = grad_func(x, theta, y)

# for practicality, this is also available
loss, (grad,) = jax.value_and_grad(mse, argnums=[1])(x,theta_, y)
print(f"MSE loss: {loss}")

MSE loss: 1.5848617553710938


In [59]:
grad.shape

(10, 2)

In [60]:
# stupid loop to optimize our model
for _ in range(100):
    loss, (grad,) = jax.value_and_grad(mse, argnums=[1])(x,theta_, y)
    theta_ -= 0.01*grad 
    print(f"MSE: {loss}")

MSE: 0.0006293614860624075
MSE: 0.0006269514560699463
MSE: 0.0006245508557185531
MSE: 0.0006221593939699233
MSE: 0.000619777652900666
MSE: 0.000617404468357563
MSE: 0.000615040073171258
MSE: 0.0006126854568719864
MSE: 0.0006103403284214437
MSE: 0.0006080031162127852
MSE: 0.0006056762067601085
MSE: 0.000603357853833586
MSE: 0.0006010483484715223
MSE: 0.0005987479817122221
MSE: 0.0005964564625173807
MSE: 0.0005941731506027281
MSE: 0.0005918987444601953
MSE: 0.0005896340007893741
MSE: 0.0005873775808140635
MSE: 0.000585129193495959
MSE: 0.000582889246288687
MSE: 0.0005806598346680403
MSE: 0.000578437524382025
MSE: 0.0005762248183600605
MSE: 0.0005740205524489284
MSE: 0.0005718244356103241
MSE: 0.0005696367588825524
MSE: 0.0005674573476426303
MSE: 0.0005652862600982189
MSE: 0.0005631243111565709
MSE: 0.0005609701620414853
MSE: 0.0005588248022831976
MSE: 0.0005566873005591333
MSE: 0.0005545581225305796
MSE: 0.0005524371517822146
MSE: 0.000550324097275734
MSE: 0.0005482195992954075
MSE: 0.00

In [65]:
# what if we want to get a Jacobian?

J, = jax.jacobian(predict, argnums=[1])(x, theta)
J.shape

(100, 2, 10, 2)

In [68]:
# what about the Hessian?
H, = jax.jacfwd(jax.jacrev(predict), argnums=[1])(x, theta)
H.shape

(100, 2, 100, 10, 10, 2)

In [94]:
# vmap usage
# say we have a "complicated function" that we want to apply row-wise, ie. over axis=0

M = np.random.uniform(1, 10, (200, 2))
func = lambda x: x[0]**2 + jnp.exp(x[1])

@jax.jit
def naive(M):
    return jnp.stack([func(x) for x in M])

@jax.jit
def with_vmap(M):
    return jax.vmap(func)(M)

print("Naive:")
%timeit -n 50 naive(M)
print("With vmap:")
%timeit -n 50 with_vmap(M)


Naive:
The slowest run took 5783.31 times longer than the fastest. This could mean that an intermediate result is being cached.
12.3 ms ± 30 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)
With vmap:
The slowest run took 235.43 times longer than the fastest. This could mean that an intermediate result is being cached.
310 µs ± 738 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [169]:
# say we use an ensemble neural network (this cause a bit of pain for me and Sebastian to implement in PyTorch)
from jax.tree_util import tree_flatten, tree_unflatten



def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.
    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.
    """
    leaves_list = []
    treedef_list = []
    for tree in trees:
        leaves, treedef = tree_flatten(tree)
        leaves_list.append(leaves)
        treedef_list.append(treedef)

    grouped_leaves = zip(*leaves_list)
    result_leaves = [jnp.stack(l) for l in grouped_leaves]
    return treedef_list[0].unflatten(result_leaves)


def tree_unstack(tree):
    """Takes a tree and turns it into a list of trees. Inverse of tree_stack.
    For example, given a tree ((a, b), c), where a, b, and c all have first
    dimension k, will make k trees
    [((a[0], b[0]), c[0]), ..., ((a[k], b[k]), c[k])]
    Useful for turning the output of a vmapped function into normal objects.
    """
    leaves, treedef = tree_flatten(tree)
    n_trees = leaves[0].shape[0]
    new_leaves = [[] for _ in range(n_trees)]
    for leaf in leaves:
        for i in range(n_trees):
            new_leaves[i].append(leaf[i])
    new_trees = [treedef.unflatten(l) for l in new_leaves]
    return new_trees



def get_nn_params():
        return [
            (np.random.uniform(-1,1, (10, 64)), np.random.uniform(-1,1, (64, 1))),
            (np.random.uniform(-1,1, (64, 64)), np.random.uniform(-1,1, (64, 1))),
            (np.random.uniform(-1,1, (64, 2)),  np.random.uniform(-1,1, (2,1)))
        ]

def forward(x, theta):
    w, b =theta[0]
    x = jax.nn.relu(x@w) + b.T
    w, b =theta[0]
    x = jax.nn.relu(x@w) + b.T
    w, b =theta[0]
    x = jax.nn.relu(x@w) + b.T
    return x


params = get_nn_params()


out = forward(x, params)
out.shape



TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [167]:
lots_of_params = [get_nn_params() for _ in range(100)]
# how do we parallelize this?

stacked_tree = tree_stack(lots_of_params)

x = np.random.uniform(-1,1, (512,10))


seq_ensemble_forward = lambda x,trees: [forward(x,tree) for tree in trees]
seq_ensemble_forward = jax.jit(seq_ensemble_forward)

vmap_ensemble_forward = jax.vmap(forward, in_axes=[None,[(0, 0), (0, 0), (0,0)]])
vmap_ensemble_forward = jax.jit(vmap_ensemble_forward)


seq_ensemble_forward = jax.jit(seq_ensemble_forward)

%timeit -n 50 seq_ensemble_forward(x, lots_of_params)
%timeit -n 50 vmap_ensemble_forward(x, stacked_tree)



The slowest run took 13.71 times longer than the fastest. This could mean that an intermediate result is being cached.
46.5 ms ± 73.1 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)
21.1 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [None]:
# JAX magic functions

# dataset
x = np.random.uniform(-1,1, (100, 10))
theta = np.random.uniform(-1,1, (10,1))
theta_ = np.random.uniform(-1,1, (10,1))




y = x @ theta # ground truth
y_ = x @ theta_ # estimates

# mse error
def mse(y_,y):
    pass
    





In [5]:
# numpy and JAX syntax is similar in most cases, except one crucial annoying one...

# random seed initialization
random_seed = 123
key = jax.random.PRNGKey(random_seed) # returns 

In [None]:
# to generate random numbers, we need to iteratively 