# Using a TensorNode class to insert a network built with the TensorFlow Keras API into a Nengo model
This tutorial builds on [Inserting a TensorFlow network into a Nengo model](./pretrained-model.ipynb). In that tutorial, we showed how to write a TensorNode class that makes it easy to insert a pre-trained network from `tf.contrib.slim` into a Nengo model.  
Sometimes, instead of using a pre-trained network, you may want to use your own neural network architecture within a TensorNode. Here we show how a network built with the `tf.keras` API can be inserted into a Nengo model using the TensorNode class. If you haven't read the previous tutorial, please work through it first to familiarize yourself with the concepts we'll use here. This tutorial also assumes familiarity with the `tf.keras` API. Specifically, it is based on the [introduction in the TensorFlow documentation](https://www.tensorflow.org/tutorials/keras/basic_classification), so if you are not yet familiar with Keras, you may find it helpful to read that tutorial first as well.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

import nengo
import nengo_dl

We'll train a neural network to classify images from the Fashion MNIST dataset.

In [58]:
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
num_classes = np.unique(test_labels).shape[0]  # there are 10 classes, just like MNIST

# normalize images so values are between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0
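
As a quick sanity check of the scaling step, the sketch below applies the same division to a small synthetic `uint8` array and confirms that the result is floating point and lies between 0 and 1. The array `fake_images` is made up for illustration; it is not Fashion MNIST data.

```python
import numpy as np

# Hypothetical stand-in for image data: 8-bit pixel values in [0, 255]
fake_images = np.array([[0, 128, 255], [64, 32, 16]], dtype=np.uint8)

# Same normalization as above; dividing by a float promotes uint8 to float64
scaled = fake_images / 255.0

print(scaled.dtype, scaled.min(), scaled.max())
```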

First we build and train a very simple fully-connected neural network, using the Keras API to work with TensorFlow.

In [59]:
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28), name='flatten'),
    keras.layers.Dense(128, activation=tf.nn.relu, name='hidden'),
    keras.layers.Dense(num_classes, activation=tf.nn.softmax, name='softmax')
])

model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [60]:
model.fit(train_images, train_labels, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fe8f59a03c8>

We save the trained weights. This way, as in the previous tutorial, we can load them in the `post_build` method of the class that we'll use inside our `TensorNode`.

In [61]:
model.save_weights('fully_connected_weights.h5')

As in the first `TensorNode` tutorial, we write a class with a `pre_build` method that specifies the structure of our neural network, this time using the high-level Keras API (by literally cutting and pasting the code we used to define our network when training it). We also define a `post_build` method, where we load the weights after our network has been built as a dataflow graph that TensorFlow can execute within its C++ runtime. And, as before, we provide a `__call__` method that gets executed at each timestep of the simulation we run with Nengo.

Notice that in the `__call__` method, we simply pass our input tensor to the Keras model *without* calling a method such as `Model.predict`, which you might intuitively reach for if you frequently work with the Keras API. We do this because we want the model to return a Tensor object, not computed predictions (e.g., a NumPy array of floats). This way the Tensor can become part of the dataflow graph that TensorFlow creates for its C++ runtime when acting as the Nengo backend.

In [62]:
model_weights = 'fully_connected_weights.h5'
image_shape = (28, 28)

class FullyConnectedNode:
    def pre_build(self, *args):
        self.model = keras.Sequential([
            keras.layers.Flatten(input_shape=image_shape),
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])

    def post_build(self, sess, rng):
        # load the saved Keras weights into the model
        self.model.load_weights(model_weights)

    def __call__(self, t, x):
        images = tf.reshape(x, (-1,) + image_shape)
        return self.model(images)
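
The `tf.reshape` call in `__call__` reverses the flattening that is applied to the images before they enter the Nengo network. Below is a minimal NumPy sketch of the same shape arithmetic. NumPy's `reshape` is used here only as a stand-in for `tf.reshape`, and the random `batch` array is made up for illustration.

```python
import numpy as np

image_shape = (28, 28)
batch = np.random.rand(5, *image_shape)  # hypothetical batch of 5 "images"

# Flatten each image into a vector, the form a Nengo input node works with
flat = batch.reshape(-1, np.prod(image_shape))

# (-1,) + image_shape == (-1, 28, 28); the -1 lets the batch size be inferred
restored = flat.reshape((-1,) + image_shape)

print(flat.shape, restored.shape)
```

Because reshaping only reinterprets the layout, `restored` contains exactly the same values as `batch`.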

To better understand the difference between just writing `model(images)` and `model.predict(images)`, you can read and/or run the code in the cell below.

In [54]:
model_weights = 'fully_connected_weights.h5'
image_shape = (28, 28)

with tf.Session() as sess:
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=image_shape, name='flatten'),
        keras.layers.Dense(128, activation=tf.nn.relu, name='hidden'),
        keras.layers.Dense(10, activation=tf.nn.softmax, name='softmax')
    ])
    model.load_weights(model_weights)
    out1 = model(tf.convert_to_tensor(test_images[:10], dtype=tf.float32))
    out2 = model.predict(test_images[:10])

print("Type of 'out1': ", type(out1))
print("Type of 'out2': ", type(out2))

Type of 'out1':  <class 'tensorflow.python.framework.ops.Tensor'>
Type of 'out2':  <class 'numpy.ndarray'>


Now that we have our `FullyConnectedNode` class, we can use it to insert our fully-connected Keras model into a Nengo network via a `TensorNode`.

Notice here that we use a NumPy `ones` vector as a dummy `output` from our `input_node`; it will be replaced by our Fashion MNIST images, flattened into vectors, when we run the `Simulator`.

In [65]:
net_input_shape = np.prod(image_shape)  # because input will be a vector

with nengo.Network() as net:
    # create a normal input node to feed in our test image
    input_node = nengo.Node(output=np.ones((net_input_shape,)),
                            label='input')

    # create our TensorNode containing the FullyConnectedNode() we defined
    # above.  we also need to specify size_in (the dimensionality of
    # our input vectors, the flattened images) and size_out (the number
    # of classes output by the fully-connected network)
    fc_node = nengo_dl.TensorNode(
        FullyConnectedNode(),
        size_in=net_input_shape,
        size_out=num_classes)

    # connect up our input to our fully-connected network node
    nengo.Connection(input_node, fc_node, synapse=None)

    # add some probes to collect data
    input_p = nengo.Probe(input_node)
    fc_p = nengo.Probe(fc_node)

In this very simple example, we only want to demonstrate *how* to use a `TensorNode` in a Nengo model, so we don't overcomplicate things with features from Nengo. However, typically you would have a Nengo model that integrates inputs over time; the simulator expects you to run for some number of time steps. So we specify a number of time steps.

In [66]:
n_steps = 200

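We also grab some images at random from our test set. Here we flatten them into vectors so we can pass them to the input node of our Nengo network, and repeat each image for every time step. With a minibatch size of 20 and 200 time steps, the resulting array has shape `(20, 200, 784)`.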

In [71]:
minibatch_size = 20

# flatten the test set so we can pass images as vectors to the input node
test_images_flat = test_images.reshape([-1, np.prod(image_shape)])
# grab some random images from test set
inds = np.random.randint(low=0, high=test_images.shape[0],
                         size=(minibatch_size,))
# tile so that we have (minibatch size, time steps, pixels)
minibatch = np.tile(test_images_flat[inds, None, :], (1, n_steps, 1))
input_feeds = {input_node: minibatch}
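
To see what the indexing and tiling do to the array shape, here is a small sketch with made-up sizes (3 images of 4 pixels each, 5 time steps; all names and sizes are illustrative). Indexing with `None` adds a time axis of length 1, and `np.tile` repeats each image along that axis so that the same image is presented at every time step.

```python
import numpy as np

n_images, n_pixels, steps = 3, 4, 5  # illustrative sizes only
flat = np.arange(n_images * n_pixels, dtype=float).reshape(n_images, n_pixels)

# Insert a time axis of length 1: (3, 4) -> (3, 1, 4)
with_time_axis = flat[:, None, :]

# Repeat along the time axis only: (3, 1, 4) -> (3, 5, 4)
tiled = np.tile(with_time_axis, (1, steps, 1))

print(with_time_axis.shape, tiled.shape)
```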

Finally we are ready to run the simulation.

In [72]:
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)

with sim:
    sim.run_steps(n_steps, input_feeds=input_feeds)

Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               


TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder_4:0", shape=(784, 128), dtype=float32) is not an element of this graph.