In [1]:
# Francisco Dominguez Mateos
# 24/06/2020
# From: https://colinraffel.com/blog/you-don-t-know-jax.html

In [2]:
import random
import itertools

import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

In [3]:
# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    l1 = np.dot(w1, x) + b1
    a1 = np.tanh(l1)
    l2 = np.dot(w2, a1) + b2
    return sigmoid(l2)

# Cross-entropy loss
def loss(params, x, y):
    pred = net(params, x)
    cross_entropy = -y * np.log(pred) - (1 - y)*np.log(1 - pred)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)

In [4]:
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),     # b1
        onp.random.randn(3),     # w2
        onp.random.randn(),      #b2
    ]

In [7]:
loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

#for n in itertools.count():
for n in range(300):
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [8]:
# Time the original gradient function
%timeit loss_grad(params, x, y)
loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

7.54 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
226 µs ± 2.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
params = initial_params()

#for n in itertools.count():
for n in range(300):
    x = inputs[onp.random.choice(inputs.shape[0])]
    y = onp.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [None]:
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

params = initial_params()

batch_size = 100

for n in itertools.count():
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break