Source: https://colinraffel.com/blog/you-don-t-know-jax.html

Goal of paper: Learn the XOR function with a neural network using JAX

import random
import itertools

import jax
import jax.numpy as np

% Current convention is to import original numpy as "onp"

import numpy as onp

from __future__ import print_function

In [2]:
import random, itertools
import jax, jax.numpy as jnp
import numpy as np

from __future__ import print_function

In [3]:
# sigmoid nonlinear activation function
def sigmoid(x): 
    ''' Sigmoid activation function , i.e. our nonlinearity '''
    return 1 / (1 + jnp.exp(-x))

def net(params, x):
    ''' Simple neural network with two hidden layers '''
    w1, b1, w2, b2 = params
        
    # first hidden layer with tanh activation function
    h1 = jnp.tanh(jnp.dot(w1, x) + b1)
        
    # second hidden layer with tanh activation function
    h2 = jnp.tanh(jnp.dot(w2, h1) + b2)
    
    # sigmoid activation for final output, which will be given 
    # to loss fcn to evaluate model's learning
    return sigmoid(h2)

def loss(params, x, y):
    ''' Compute the loss, in this case cross entropy '''
    # get output from network
    output = net(params, x)
    
    # compute binary cross entropy between output and ideal output y
    cross_entropy = -y * jnp.log(output) - (1 - y) * jnp.log(1 - output)
    
    return cross_entropy

def test_all_inputs(inputs, params):
    ''' Utility fcn for testing all inputs '''
    # get predictions from network
    # do this by squashing output into either 0 or 1 from model
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
        
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    
    # test whether the predictions actually math the XOR
    return (predictions == [np.bitwise_xor(*inp) for inp in inputs])

def initial_params():
    return [
        np.random.randn(3, 2),  # w1
        np.random.randn(3),  # b1
        np.random.randn(3),  # w2
        np.random.randn(),  # b2
    ]

Use jax.grad to evaluate the gradients of the loss with respect to the networks parameters. Recall how the computational graph changes

In [None]:
# create a fcn which evaluates the gradient, in this case it will be loss
loss_grad = jax.grad(loss)

# start network parameters off in random configuration
params = initial_params()

learning_rate = 1.

xor_inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])

for n in itertools.count():
    # input to the network for training
    x = xor_inputs[np.random.choice(xor_inputs.shape[0])] # (2,)
    
    # ideal output that will be used for training, i.e. cross entropy
    y = np.bitwise_xor(*x) # (,)
    
    # compute gradients, and use for updating
    grads = loss_grad(params, x, y)
    
    # update params based on found gradients
    params = [param - learning_rate * grad for param, grad in zip(params, grads)]
    
    # every 100 iterations, check to see if we've solved XOR
    if not n % 100:
        #print(f'iteration: {n}')
        if test_all_inputs(xor_inputs, params):
            print(f'n: {n}')
            break
            
    if n == 3000:
        break

Now let's implement jax.jit to get the speed ups we want for using JAX, that's the whole point as well as their out-of-the-box methods that I'll be using

In [None]:
# time the non-jit jax.grad
%timeit loss_grad(params, x, y)

# jit compile loss fcn
loss_grad = jax.jit(jax.grad(loss))

# now call the jit version of loss_grad
%timeit loss_grad(params, x, y)

3 orders of magnitude speed up is ... a lot!

Now let's try training the model again and see what happens


In [4]:
params = initial_params()
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
learning_rate = 1.
loss_grad = jax.jit(jax.grad(loss))

for n in itertools.count():
    x = inputs[np.random.choice(inputs.shape[0])]
    y = np.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 400
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 500
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 600
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 700
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 800
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 900
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1000
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1400
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1500
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 1600
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration

Iteration 14000
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14400
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14500
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14600
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14700
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14800
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 14900
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15000
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15400
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15500
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 15600
[0 0] -> 0
[0 1] -> 0
[1

Exception ignored in: <function _xla_gc_callback at 0x7a1a18bab130>
Traceback (most recent call last):
  File "/home/czp/py/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 118, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


KeyboardInterrupt: 