In [121]:
from __future__ import print_function

import itertools

import jax
from jax import jit
import jax.numpy as np
from jax.experimental import stax, optimizers
from jax.experimental.stax import Dense, elementwise

import numpy as onp

import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from activations import sigmoid
from losses import create_loss, crossentropy as cse


In [122]:
Tanh = elementwise(np.tanh)
Sigmoid = elementwise(sigmoid)

init_random_params, net = stax.serial(
    Dense(3), Tanh,
    Dense(1), Sigmoid)

loss = create_loss(net, cse)

In [123]:
def test_all_inputs(inputs, params):
    """Tests all possible xor inputs and outputs"""
    print(net(params, inputs).ndim)
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

In [124]:
# loss_grad = jax.jit(jax.grad(loss))
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=-1))

In [125]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    x, y = batch
    g = loss_grad(params, x, y)
    g = [np.mean(grad, axis=-1) for grads in g for grad in grads]
    return opt_update(i, g, opt_state)

In [126]:
rng = jax.random.PRNGKey(0)

inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

opt_init, opt_update, get_params = optimizers.sgd(0.5)
_, init_params = init_random_params(rng, (-1, 2))
opt_state = opt_init(init_params)
itercount = itertools.count()

In [127]:
get_params(opt_state)

[(DeviceArray([[ 0.37440315, -0.94589597,  0.47653025],
               [ 0.3341504 , -0.57999146, -0.74206877]], dtype=float32),
  DeviceArray([ 0.00156939,  0.01631498, -0.00894601], dtype=float32)),
 (),
 (DeviceArray([[ 0.52601111],
               [-0.35282513],
               [-0.82455486]], dtype=float32),
  DeviceArray([0.01620002], dtype=float32)),
 ()]

In [130]:
x = inputs[onp.random.choice(inputs.shape[0], size=4)]
y = onp.bitwise_xor(x[:, 0], x[:, 1])
params = get_params(opt_state)
g = loss_grad(params, x, y)
g = [np.mean(grad, axis=-1) for grads in g for grad in grads]

In [131]:
g

[DeviceArray([[-0.05431693,  0.0195148 ,  0.07915428],
              [-0.0342512 ,  0.01896921,  0.03573295]], dtype=float32),
 DeviceArray([ 0.04416027, -0.05052089, -0.09315658], dtype=float32),
 DeviceArray([[-0.06573147],
              [ 0.12785722],
              [-0.00773748]], dtype=float32),
 DeviceArray([0.0610374], dtype=float32)]

In [129]:
print("\nStarting training...")

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0], size=4)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    batch = (x, y)

    opt_state = update(next(itercount), opt_state, batch)

    params = get_params(opt_state)
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, get_params(opt_state)):
            break


Starting training...


TypeError: optimizer update function was passed a gradient tree that did not match the parameter tree structure with which it was initialized: parameter tree PyTreeDef(list, [PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, []),PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, [])]) and grad tree PyTreeDef(list, [*,*,*,*]).

In [11]:
get_params(opt_state)

[(DeviceArray([[[ 0.25790906, -0.9237622 ,  0.7468996 ],
                [ 0.21765631, -0.55785769, -0.47169942]],
  
               [[ 0.37440315, -0.94589597,  0.47653025],
                [ 0.3341504 , -0.57999146, -0.74206877]],
  
               [[ 0.37440315, -0.94589597,  0.47653025],
                [ 0.4026528 , -0.61792988, -0.81353468]],
  
               [[ 0.37440315, -0.94589597,  0.47653025],
                [ 0.3341504 , -0.57999146, -0.74206877]]], dtype=float32),
  DeviceArray([[-0.11492471,  0.03844877,  0.26142335],
               [-0.13115901,  0.10531988,  0.19909781],
               [ 0.07007179, -0.02162345, -0.08041191],
               [-0.13115901,  0.10531988,  0.19909781]], dtype=float32)),
 (),
 (DeviceArray([[[ 0.3102755 ],
                [-0.03249529],
                [-0.72996289]],
  
               [[ 0.5256151 ],
                [-0.35694155],
                [-0.82229757]],
  
               [[ 0.57309186],
                [-0.4271155 ],
           