<a href="https://colab.research.google.com/github/maxmatical/jax_projects/blob/master/01_jax_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import jax
import jax.numpy as np
import numpy as onp


# Training XOR in JAX with a one-hidden-layer neural network

In [5]:
# All four binary input pairs for XOR, one pair per row.
inputs = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])
print(inputs.shape)

(4, 2)


In [0]:
def sigmoid(x):
    """Elementwise logistic function, 1 / (1 + exp(-x))."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def net(params, x):
    """Forward pass of the one-hidden-layer XOR network.

    params: [w1, b1, w2, b2] as produced by `init_params`
        (w1: [nh, ni], b1: [nh], w2: [nh], b2: scalar).
    x: a single example of shape [ni], or a batch of examples of shape [n, ni].

    Returns the sigmoid output: a scalar for a single example, or shape [n]
    for a batch.
    """
    w1, b1, w2, b2 = params  # hidden-layer weights/bias, output-layer weights/bias
    # x @ w1.T equals w1 @ x for a single example ([ni] -> [nh]), and also
    # handles a batch of rows: [n, ni] @ [ni, nh] -> [n, nh].
    h1 = np.maximum(np.dot(x, w1.T) + b1, 0)  # affine map then ReLU, i.e. max(0, .)
    # [nh] . [nh] -> scalar, or [n, nh] . [nh] -> [n]
    out = sigmoid(np.dot(h1, w2) + b2)
    return out

def loss(params, x, y, eps=1e-12):
    """Binary cross-entropy between the network output and the target label.

    params: [w1, b1, w2, b2] network parameters (see `net`).
    x: one input example.
    y: target label, 0 or 1.
    eps: floor applied to each probability inside the logs. It only changes
        the result when a probability is below eps, in particular when the
        sigmoid output has rounded to exactly 0.0 or 1.0.

    Returns the per-example loss. For a single example this is a scalar,
    which is what `jax.grad` requires.
    """
    out = net(params, x)
    # Without the floor, out == 1.0 makes log(1 - out) = -inf, and even when
    # y == 1 the product (1 - y) * -inf is 0 * -inf = NaN, so the loss and
    # its gradient both become NaN. Flooring keeps every term finite.
    log_p = np.log(np.maximum(out, eps))
    log_not_p = np.log(np.maximum(1 - out, eps))
    cross_entropy = -y*log_p - (1-y)*log_not_p
    return cross_entropy


def test_all_inputs(inputs, params):
    """Print the network's 0/1 prediction for every input and check it against XOR.

    A prediction is 1 when the network output exceeds 0.5, otherwise 0.
    Returns True only if every prediction equals the XOR of that input's two bits.
    """
    predictions = []
    for sample in inputs:
        predictions.append(int(net(params, sample) > 0.5))

    for sample, predicted in zip(inputs, predictions):
        print(sample, '->', predicted)

    expected = [onp.bitwise_xor(*sample) for sample in inputs]
    return predictions == expected


Initialize weights

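Note: the weights are drawn with `onp` (plain NumPy) because its global random state is easier to use here than JAX's RNG, which requires passing PRNG keys explicitly.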

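Shapes, with n examples, ni input features and nh hidden units:

- a single example x has shape [ni]; a batch X has shape [n, ni]
- w = w1 has shape [nh, ni]

For a single example, the hidden pre-activation $W_1 x$ has shape [nh]. For a batch, the equivalent product is $X W_1^{T}$, with shape [n, nh]. (`np.dot(w, x)` with a batch X of shape [n, ni] would fail, because the inner dimensions ni and n do not match.) In the training loops below, batching is handled by `jax.vmap`.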

In [0]:
# init_params (next cell) draws every parameter with NumPy's randn, i.e. samples from a standard normal distribution.

In [0]:
def init_params(n_inputs=2, n_hidden=3, seed=None):
    """Sample initial parameters for `net` from a standard normal distribution.

    n_inputs: number of features per example (ni); 2 for XOR.
    n_hidden: number of hidden units (nh).
    seed: optional integer seed. When given, a private
        `onp.random.RandomState(seed)` is used, so the draw is reproducible
        and the global RNG is left untouched. When None (the default), the
        global `onp.random` state is used.

    Returns [w1, b1, w2, b2] with shapes [nh, ni], [nh], [nh] and a scalar.
    """
    # The module `onp.random` and a RandomState both expose `randn`, so the
    # same calls work for either source of randomness.
    rng = onp.random if seed is None else onp.random.RandomState(seed)
    return [
        rng.randn(n_hidden, n_inputs),  # w1 shape = [nh, ni]
        rng.randn(n_hidden),            # b1 shape = [nh]
        rng.randn(n_hidden),            # w2 shape = [nh]
        rng.randn(),                    # b2 scalar
    ]


In [0]:
# Gradient of the loss with respect to its first argument, the parameter list.
loss_grad = jax.grad(loss, argnums=0)

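Start training

Plain SGD: each step samples one input, computes the gradient of its loss and moves every parameter against it with learning rate `lr`. Every 100 steps, `test_all_inputs` prints the prediction for each input and whether all four match XOR. In the recorded output below, this check never prints `True` within 1000 steps.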

In [62]:
%%time
# Plain SGD: one randomly chosen example per step, 1000 steps.
params = init_params()
lr = 1e-1
for i in range(1000):
    # sample one of the four XOR inputs uniformly at random, plus its label
    x = inputs[onp.random.choice(len(inputs))]
    y = onp.bitwise_xor(x[0], x[1])

    # gradient w.r.t. params, then a step against it on every parameter array
    grads = loss_grad(params, x, y)
    params = [param - lr * grad for param, grad in zip(params, grads)]

    # every 100 steps, show predictions and whether all four match XOR
    if (i + 1) % 100 == 0:
        print(f'Iteration {i+1}')
        print(test_all_inputs(inputs, params))


Iteration 100
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 200
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 400
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 500
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 600
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 700
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 800
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 900
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 1000
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
CPU times: user 18 s, sys: 3.45 s, total: 21.4 s
Wall time: 16.8 s


## JIT-compiling the gradient with `jax.jit`

In [0]:
# Same gradient function as before, now wrapped in jax.jit so it is compiled with XLA.
loss_grad = jax.jit(jax.grad(loss, argnums=0))

In [64]:
%%time
# Same per-example SGD loop, now calling the jitted gradient function.
params = init_params()
lr = 1e-1
for i in range(1000):
    x = inputs[onp.random.choice(inputs.shape[0])]  # one random example
    y = onp.bitwise_xor(*x)                          # its XOR label
    grads = loss_grad(params, x, y)                  # jitted gradient w.r.t. params
    # SGD step: move each parameter array against its gradient
    params = list(map(lambda p, g: p - lr * g, params, grads))

    # progress report every 100 steps
    if not (i + 1) % 100:
        print(f'Iteration {i+1}')
        print(test_all_inputs(inputs, params))

Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
False
Iteration 200
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 1
False
Iteration 300
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
False
Iteration 400
[0 0] -> 1
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
False
Iteration 500
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 600
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 1
False
Iteration 700
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 800
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 900
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
False
Iteration 1000
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
False
CPU times: user 4.53 s, sys: 1.06 s, total: 5.6 s
Wall time: 4.29 s


## Batching with `jax.vmap`

For `jax.vmap`:

`in_axes`: which axis of each input to map over
- `(None, 0, 0)` for `(params, x, y)`: x and y are mapped over their 0th (batch) dimension, while params are not mapped and are shared by every example

`out_axes`: the axis of the output along which the per-example results are stacked
- 0, so the per-example gradients are stacked along the 0th (batch) dimension; the training loop then averages them with `np.mean(g, axis=0)`

In [0]:
loss_grad = jax.jit(
    jax.vmap(
        jax.grad(loss),
        in_axes=(None, 0, 0),  # params shared by every example; x and y mapped over axis 0
        out_axes=0,            # per-example gradients stacked along a new leading axis
    )
)

In [74]:
%%time
# Mini-batch SGD: each step samples bs examples, computes per-example
# gradients with the vmapped loss_grad, and steps along their batch mean.
params = init_params()
lr = 1e-1
bs = 128
for i in range(1000):
    # draw a batch of bs examples (onp.random.choice samples with replacement by default)
    x = inputs[onp.random.choice(inputs.shape[0], size = bs)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1]) # XOR labels for the whole batch, shape [bs]
    grads = loss_grad(params, x, y)  # one gradient per example, stacked along axis 0
    # update parameters with the mean gradient over the batch axis
    params = [p-lr*np.mean(g, axis=0)
        for p, g in zip(params, grads)]

    # print stats
    if (i+1)%100 == 0:
        print(f'Iteration {i+1}')
        print(test_all_inputs(inputs, params))

Iteration 100
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 200
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 300
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 400
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 500
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 600
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 700
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 800
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 900
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
Iteration 1000
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
False
CPU times: user 10.8 s, sys: 2.68 s, total: 13.5 s
Wall time: 10.4 s
