In [1]:
%load_ext autoreload
%autoreload 2
from nram import *
from opt import *
from gate import *
from data import *
from nnet import *

In [2]:
import numpy as np
import theano

In [317]:
theano.config.optimizer = 'fast_run'

In [719]:
num_registers = 2
num_timesteps = 2
max_int = 10
gates = [one, zero, write, read]
gate_names = ["1", "0", "Write", "Read"]

In [774]:
## Goal task: delete the first character from a null-terminated string.

## Task 1: Check whether the first character is a null-terminator.
## Task 3: Shift all characters over until the null terminator
def task_one(max_int, batch_size):
    # Create a random initial memory.
    init_mem = np.random.randint(0, max_int, size=(batch_size, max_int), dtype=np.int32)
    
    # Seed half of it with input zeros.
    seed = np.random.choice([0, 1], size=(batch_size, 1), p=[0.44, 0.56])
    init_mem[:, 0] *= seed[:, 0]
    
    # Generate an output.
    out_mem = init_mem.copy()
    out_mem[:, 1] = np.where(init_mem[:, 0] == 0, 1, 0)
    
    return init_mem, out_mem

## Task 2: Find the null terminator.
def task_two(max_int, batch_size):
    # Create a random initial memory.
    init_mem = np.random.randint(1, max_int, size=(batch_size, max_int), dtype=np.int32)
    
    # Choose where to put the zeros.
    locs = np.random.choice(range(max_int - 1), size=(batch_size, ))
    init_mem[np.arange(batch_size), locs] = 0
    
    # Generate an output.
    out_mem = init_mem.copy()
    out_mem[np.arange(batch_size), locs + 1] = 1
    
    return init_mem, out_mem

In [775]:
def generate_batch(max_int, batch_size=1000):
    def make_batch(timestep):
        init_mem, out_mem = task_two(max_int, batch_size)
        cost_mask = np.ones_like(out_mem, dtype=np.int8)

        return encode(init_mem, max_int), out_mem, cost_mask
    return make_batch

In [511]:
def generate_batch(max_int, batch_size=1000):
    def make_batch(timestep):
        # Random initial memory
        init_mem = np.random.randint(1, max_int, size=(batch_size, max_int),
                                     dtype=np.int32)
        # decide?
        one_zero = np.random.randint(0, 2, size=(batch_size,), dtype=np.int32) * 3
        init_mem[np.arange(batch_size), 0] = one_zero
        cost_mask = np.ones((batch_size, max_int), dtype=np.int8)
        out_mem = init_mem.copy()
        out_mem[np.arange(batch_size), 1] = np.where(one_zero == 0, 1, 2)

        return encode(init_mem, max_int), out_mem, cost_mask
    return make_batch

In [776]:
layer_sizes = []
params = list(mlp_weights(num_registers, layer_sizes, gates))

In [777]:
reg_lambda = 0
result = run(gates, num_registers, max_int, num_timesteps, len(layer_sizes), reg_lambda, params)
debug, init_mem, desired_mem, cost_mask, final_mem, final_cost = result

In [778]:
gradients = theano.grad(final_cost, params)
train = theano.function([init_mem, desired_mem, cost_mask], [final_cost] + gradients)
predict = theano.function([init_mem], final_mem)

keys = list(debug.keys())
values = [debug[k] for k in keys]
predict_instrumented = theano.function([init_mem, desired_mem, cost_mask], values)
def predict_debug(*args):
    return dict(zip(keys, predict_instrumented(*args)))

In [None]:
adam_optimize(params, generate_batch(max_int), train)

Cost (t =    0): 	166043.46
Cost (t =  100): 	162347.83
Cost (t =  200): 	168843.07


In [717]:
b = generate_batch(max_int)
inputs, outputs, mask = b(1500)
percent_correct(predict, inputs, outputs)

100.0

In [718]:
r = predict_debug(inputs, outputs, mask)

In [732]:
def inspect(debug, sample, num_timesteps, num_registers, gate_names, gates):
    """Utility for inspecting the result of the network."""
    
    def get(name, timestep):
        return debug["%d:%s" % (timestep, name)][sample, :].argmax()
    
    def fmt(i):
        if i in range(num_registers):
            return "R" + str(i + 1)
        else:
            return "G" + str(i - num_registers)
        
    output = ""
    output += "Init: %s\n" % debug["0:gate-mem-0"][sample, :, :].argmax(axis=1)
    for timestep in range(num_timesteps):
        output += "Timestep %d:\n" % timestep
        for r in range(num_registers):
            src = fmt(get("coeff-reg-%d" % r, timestep))
            val = get("reg-%d" % r, timestep)
            output += "\tR%d' = %s\t\t%d\n" % (r, src, val)
        for g, (name, gate) in enumerate(zip(gate_names, gates)):
            src = ", ".join(fmt(get("coeff-gate-%d/%d" % (g, a), timestep)) for a in range(gate.arity))
            val = get("gate-out-%d" % g, timestep)
            output += "\tG%d' = %s(%s)\t\t%d\n" % (g, name, src, val)
        output += "\tComplete -> %.3f\n" % debug["%d:complete" % timestep][sample, 0]
        output += "\tMem -> %s\n" % str(debug["%d:gate-mem-%d" % (timestep, g)][sample, :, :].argmax(axis=1))
    return output

In [759]:
x += 1
print(inspect(r, x, num_timesteps, num_registers, gate_names, gates))

Init: [4 8 8 2 2 7 2 5 9 8]
Timestep 0:
	R0' = G1		0
	R1' = G3		4
	G0' = 1()		1
	G1' = 0()		0
	G2' = Write(G0, G0)		1
	G3' = Read(R2)		4
	Complete -> 0.000
	Mem -> [4 1 8 2 2 7 2 5 9 8]
Timestep 1:
	R0' = G1		0
	R1' = G3		2
	G0' = 1()		1
	G1' = 0()		0
	G2' = Write(G0, G1)		0
	G3' = Read(R2)		2
	Complete -> 0.000
	Mem -> [4 0 8 2 2 7 2 5 9 8]

