In [1]:
from __future__ import print_function

import itertools

import jax
from jax import jit, grad, vmap
import jax.numpy as np

import numpy as onp

import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from fastax import optimizers
from fastax.layers import Dense, elementwise, serial
from fastax.activations import sigmoid
from fastax.losses import batch_loss, crossentropy as cse

In [2]:
# Element-wise activation layers wrapping plain functions.
Tanh, Sigmoid = elementwise(np.tanh), elementwise(sigmoid)

# Hidden layer of 3 tanh units feeding a single sigmoid output unit.
# `init_random_params` builds the weights; `net(params, x)` runs the forward pass.
init_random_params, net = serial(Dense(3), Tanh, Dense(1), Sigmoid)

# Cross-entropy loss of `net`, built with fastax's curried `batch_loss` helper
# (presumably reduced over the batch — confirm in fastax.losses).
loss = batch_loss(net)(cse)



In [3]:
@jit
def predict(params, inp):
    """Return the network's rounded output for one example as a ``uint32`` label.

    ``net`` ends in a single-unit layer, so element 0 of its output is the
    prediction. Rounding a sigmoid output yields 0 or 1, which is then cast
    to ``uint32``.
    """
    raw_output = net(params, inp)[0]
    return np.round(raw_output).astype(np.uint32)

def test_all_inputs(inputs, params):
    """Print the prediction for each XOR input and report whether all are correct.

    Args:
        inputs: 2-D integer array whose rows are input pairs (the XOR cases).
        params: network parameters, e.g. as returned by ``get_params``.

    Returns:
        ``True`` if every rounded prediction equals the XOR of its input pair,
        otherwise ``False`` (a plain Python ``bool``).
    """
    # Evaluate `predict` on every row at once; params are shared across rows.
    predictions = vmap(predict, in_axes=(None, 0))(params, inputs)
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    # Vectorised targets: XOR of the two columns, instead of a per-row Python
    # loop that round-trips each JAX scalar through host NumPy.
    expected = np.bitwise_xor(inputs[:, 0], inputs[:, 1])
    return bool((predictions == expected).all())

In [4]:
# JIT-compiled gradient of `loss` with respect to its first argument (the
# parameters); `update` calls it as loss_grad(params, x, y).
loss_grad = jit(grad(loss, argnums=0))

In [5]:
@jit
def update(i, opt_state, batch):
    """Take one optimizer step on ``batch`` and return the new optimizer state.

    ``batch`` is an ``(x, y)`` pair. ``get_params`` and ``opt_update`` are
    looked up as globals at trace time; they are created in the setup cell,
    so that cell must run before the first call.
    """
    batch_x, batch_y = batch
    current_params = get_params(opt_state)
    gradients = loss_grad(current_params, batch_x, batch_y)
    return opt_update(i, gradients, opt_state)

In [6]:
# All four XOR input pairs, in binary order: 00, 01, 10, 11.
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Plain SGD with a fixed step size of 0.5.
opt_init, opt_update, get_params = optimizers.sgd(0.5)
itercount = itertools.count()  # global step counter consumed by the training loop

# A fixed PRNG key makes the initial weights reproducible. Inputs have 2
# features (-1 presumably leaves the batch dimension unspecified, as in stax);
# the returned output shape is not needed.
rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (-1, 2))
opt_state = opt_init(init_params)

In [7]:
print("\nStarting training...")

# Upper bound on steps so a run that never solves XOR terminates instead of
# looping forever.
MAX_ITERATIONS = 10000
BATCH_SIZE = 100
CHECK_EVERY = 100

# Dedicated, seeded sampler so the minibatch sequence is reproducible
# (the global NumPy RNG was never seeded).
batch_rng = onp.random.RandomState(0)

# A single counter drives both the optimizer step and the logging, so
# re-running this cell keeps them in sync.
for n in itertools.islice(itercount, MAX_ITERATIONS):
    # Sample a batch of XOR input rows (with replacement) and their labels.
    x = inputs[batch_rng.choice(inputs.shape[0], size=BATCH_SIZE)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    batch = (x, y)

    opt_state = update(n, opt_state, batch)

    params = get_params(opt_state)
    # Periodically check whether we've solved XOR.
    if n % CHECK_EVERY == 0:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break
else:
    print('XOR not solved within {} iterations'.format(MAX_ITERATIONS))


Starting training...
Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
