In [1]:
from __future__ import print_function

import itertools

import jax
from jax import jit
import jax.numpy as np
from jax.experimental import stax, optimizers
from jax.experimental.stax import Dense, elementwise

import numpy as onp

import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from activations import sigmoid
from losses import create_loss, crossentropy as cse


In [2]:
Tanh = elementwise(np.tanh)
Sigmoid = elementwise(sigmoid)

init_random_params, net = stax.serial(
    Dense(3), Tanh,
    Dense(1), Sigmoid)

loss = create_loss(net, cse)

In [3]:
def test_all_inputs(inputs, params):
    """Tests all possible xor inputs and outputs"""
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

In [4]:
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=-1))

In [5]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    x, y = batch
    all_grads = loss_grad(params, x, y)
    
    for i, grads in enumerate(all_grads):
        if len(grads) > 0:
            all_grads[i] = tuple(np.mean(g, axis=-1) for g in grads)
        else:
            all_grads[i] = ()
    return opt_update(i, all_grads, opt_state)

In [7]:
rng = jax.random.PRNGKey(0)

inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

opt_init, opt_update, get_params = optimizers.sgd(0.5)
_, init_params = init_random_params(rng, (-1, 2))
print(_, params.shape)
opt_state = opt_init(init_params)
itercount = itertools.count()

NameError: name 'params' is not defined

In [None]:
print("\nStarting training...")

for n in itertools.count():
    x = inputs[onp.random.choice(inputs.shape[0], size=100)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    batch = (x, y)

    opt_state = update(next(itercount), opt_state, batch)

    params = get_params(opt_state)
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, get_params(opt_state)):
            break