In [None]:
import nengo
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

import nengo_deeplearning.backends.theano.optimizers as opt
import nengo_deeplearning.processes as proc
from nengo_deeplearning.backends.theano.networks import RNN
from nengo_deeplearning.backends.theano.layers import Generic, GatedRecurrent, Dense
from nengo_deeplearning.inputs.mnist import load_mnist

# Need 10 clearly distinguishable colors, one per digit class; this is the
# 10-class "Paired" qualitative scheme from http://colorbrewer2.org/
DIGIT_COLORS = [
    "#a6cee3",
    "#1f78b4",
    "#b2df8a",
    "#33a02c",
    "#fb9a99",
    "#e31a1c",
    "#fdbf6f",
    "#ff7f00",
    "#cab2d6",
    "#6a3d9a",
]
# `axes.color_cycle` was deprecated in matplotlib 1.5 (and later removed) in
# favour of `axes.prop_cycle`, which takes a cycler over line properties.
plt.rc('axes', prop_cycle=plt.cycler(color=DIGIT_COLORS))

## Train an RNN to classify MNIST

In [None]:
trX, teX, trY, teY = load_mnist(data_dir='.')  # Will download to data_dir if they don't exist

# RNN processes a size 28 vector at a time scanning from left to right
layers = [
    Generic(size=28),
    GatedRecurrent(size=512, p_drop=0.2),
    Dense(size=10, activation='softmax', p_drop=0.5)
]

# A bit of l2 helps with generalization, higher momentum helps convergence
optimizer = opt.NAG(momentum=0.95, regularizer=opt.Regularizer(l2=1e-4))

# Linear iterator for real valued data, cce cost for softmax
model = RNN(layers=layers, optimizer=optimizer, iterator='linear', cost='cce')
model.fit(trX, trY, n_epochs=10)


def accuracy(targets, probs):
    """Return the fraction of rows where argmax(probs) equals argmax(targets).

    Both arrays are compared along axis 1 (one row per example, one column
    per class); targets are presumably one-hot labels -- confirm with
    load_mnist.
    """
    return np.mean(np.argmax(targets, axis=1) == np.argmax(probs, axis=1))


# Score training accuracy on a subset the same size as the test set, so both
# numbers are computed over the same number of examples.
n_eval = len(teY)
tr_preds = model.predict(trX[:n_eval])
te_preds = model.predict(teX)

tr_acc = accuracy(trY[:n_eval], tr_preds)
te_acc = accuracy(teY, te_preds)

# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("  ====== Results ======")
print("Train accuracy %s\tTest accuracy %s" % (tr_acc, te_acc))
model.save('trained_rnn.pkl')

## Integrating with Nengo

In [None]:
# Reload the data here so this section does not depend on variables left over
# from the training cell; only the saved model file is shared.
trX, teX, trY, teY = load_mnist(data_dir='.')
trX = trX.reshape((-1, 28))  # one 28-pixel image row per array row
dl_net = RNN.load('trained_rnn.pkl')  # written by model.save(...) above

with nengo.Network() as net:
    # Present the image rows one after another, each for `presentation_time`.
    mnist_in = nengo.Node(
        proc.PresentInput(trX, presentation_time=0.01),
        size_out=28,
    )
    # Wrap the trained RNN as a Nengo node. history=28 presumably buffers the
    # last 28 inputs (one full image) -- TODO confirm against DLBlackBox.
    rnn = nengo.Node(
        proc.DLBlackBox(dl_net, history=28),
        size_in=28,
        size_out=10,
    )
    # No synapse: rows reach the RNN unfiltered.
    nengo.Connection(mnist_in, rnn, synapse=None)
    pr_in, pr_out = nengo.Probe(mnist_in), nengo.Probe(rnn)

In [None]:
# NB! dt must match the PresentInput presentation time (0.01 s) so that each
# simulator step delivers exactly one image row to the RNN.
sim = nengo.Simulator(net, dt=0.01)
sim.run(2.0)

t = sim.trange()
# pcolormesh (flat shading) expects cell *edges*: one more x value than time
# samples and one more y value than pixels. Passing coordinates the same size
# as the data silently drops the last row/column in older matplotlib and is
# rejected by newer versions, so build the edges explicitly.
t_edges = np.append(t[0] - sim.dt, t)
pixel_edges = np.arange(28 + 1)

fig, (ax_in, ax_out) = plt.subplots(2, 1, figsize=(14, 6))

ax_in.pcolormesh(t_edges, pixel_edges, sim.data[pr_in].T, cmap=plt.cm.binary)
ax_in.set_ylim(0, 28)
ax_in.set_ylabel('pixel within row')
ax_in.set_title('MNIST rows presented to the RNN')

ax_out.plot(t, sim.data[pr_out])
ax_out.legend([str(digit) for digit in range(10)], loc='best', fontsize='small')
ax_out.set_xlabel('time (s)')
ax_out.set_ylabel('RNN output')

# Digits are rotated left 90 degrees, but you get the idea
fig.tight_layout()