# Using a TensorNode class to insert a network built with the TensorFlow Keras API into a Nengo model
This tutorial builds on [Inserting a TensorFlow network into a Nengo model](./pretrained-model.ipynb). In that tutorial, we showed how to write a TensorNode class that makes it easy to insert a pre-trained network from `tf.contrib.slim` into a Nengo model.  
Sometimes, instead of using a pre-trained network, you may want to use your own neural network architecture within a TensorNode. Here we show how a network built with the `tf.keras` API can be inserted into a Nengo model using a TensorNode class. If you haven't read the previous tutorial, please work through it first to familiarize yourself with the concepts we'll use here. This tutorial also assumes familiarity with the `tf.keras` API. Specifically, it is based on the [introduction in the TensorFlow documentation](https://www.tensorflow.org/tutorials/keras/basic_classification), so if you are not yet familiar with Keras, you may find it helpful to work through that tutorial first as well.

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

import nengo
import nengo_dl

In [4]:
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()


In [3]:
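model = keras.Sequential([
    # flatten each 28x28 image into a 784-dimensional vector
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    # one output per clothing class, normalized to probabilities
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

# the labels are integers, so we use the sparse version of categorical crossentropy
model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])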
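To make the architecture concrete, here is a NumPy-only sketch of the forward pass this model computes (flatten each 28x28 image, a 128-unit ReLU layer, then a 10-way softmax), together with the `sparse_categorical_crossentropy` loss, which takes integer labels rather than one-hot vectors. The function names and the random weights are illustrative assumptions; this is not how Keras implements the layers internally, and trained weights would of course differ.

```python
import numpy as np

# NumPy sketch of: Flatten -> Dense(128, relu) -> Dense(10, softmax).
# The weights are random placeholders, NOT the trained Keras weights.

def init_params(rng, n_in=28 * 28, n_hidden=128, n_out=10):
    """Random weights with the same shapes as the Keras Dense layers."""
    return {
        "W1": rng.normal(scale=0.05, size=(n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(scale=0.05, size=(n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }

def softmax(z):
    # subtract the row-wise max for numerical stability
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def forward(params, images):
    """images: (batch, 28, 28) -> class probabilities: (batch, 10)."""
    x = images.reshape(images.shape[0], -1)               # Flatten
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # Dense(128, relu)
    return softmax(h @ params["W2"] + params["b2"])       # Dense(10, softmax)

def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    """Mean negative log-probability assigned to each example's integer label."""
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(picked, eps, 1.0))))

rng = np.random.default_rng(0)
params = init_params(rng)
images = rng.random((4, 28, 28))   # stand-in for normalized Fashion-MNIST images
labels = np.array([0, 3, 9, 1])    # integer class labels, like train_labels
probs = forward(params, images)
print(probs.shape)                 # (4, 10)
print(sparse_categorical_crossentropy(probs, labels))
```

Note that a model predicting uniform probabilities over the 10 classes has a loss of exactly log(10), which is a useful sanity check for the starting loss during training.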

In [5]:
train_images = train_images / 255.0

test_images = test_images / 255.0

In [5]:
model.fit(train_images, train_labels, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fa57d466da0>

In [8]:
model.save_weights('fully_connected_weights.h5')

In [6]:
model_weights = 'fully_connected_weights.h5'
image_shape = (28, 28)

class FullyConnectedNode(object):
    def pre_build(self, *args):
        self.model = keras.Sequential([
            keras.layers.Flatten(input_shape=image_shape),
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])

    def post_build(self, sess, rng):
        # load the weights saved above (an HDF5 file) into the model.
        # Compiling is not needed just to run inference, and making `sess`
        # the default session ensures that Keras assigns the weights inside
        # the NengoDL simulator's session
        with sess.as_default():
            self.model.load_weights(model_weights)

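    def __call__(self, t, x):
        # this function builds the TensorFlow ops that will be executed
        # each timestep while the network is running

        # convert our flattened input vectors back to the shape of the
        # input images
        images = tf.reshape(x, (-1,) + image_shape)

        # call the model directly on the tensor to get a symbolic output.
        # model.predict() is meant for evaluating the model on data, not for
        # building graph ops; given a tensor, it raises a ValueError asking
        # for the `steps` argument
        probabilities = self.model(images)
        return probabilities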
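The reshape in `__call__` undoes the flattening done on the Nengo side. Assuming, as in NengoDL 2.x, that `x` arrives with shape `(minibatch_size, size_in)`, the NumPy sketch below mimics that round trip. Both `tf.reshape` and NumPy's default `reshape` use row-major order, so the restored images match the originals exactly.

```python
import numpy as np

image_shape = (28, 28)
minibatch_size = 10

rng = np.random.default_rng(0)
images = rng.random((minibatch_size,) + image_shape)

# What the TensorNode receives (assumed): one flattened 784-d vector per
# minibatch item, i.e. shape (minibatch_size, size_in).
size_in = int(np.prod(image_shape))
x = images.reshape(minibatch_size, size_in)

# NumPy analogue of tf.reshape(x, (-1,) + image_shape) in __call__;
# the -1 lets the batch dimension be inferred from the data.
restored = np.reshape(x, (-1,) + image_shape)

print(x.shape)         # (10, 784)
print(restored.shape)  # (10, 28, 28)
```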

In [7]:
train_shape = train_images.shape

with nengo.Network() as net:
    # create a normal input node to feed in our test image
    input_node = nengo.Node([0] * np.prod(train_shape[1:]))

    # create our TensorNode containing the FullyConnectedNode() we defined
    # above. We also need to specify size_in (the dimensionality of
    # our input vectors, the flattened images) and size_out (the number
    # of classes output by our fully-connected network)
    fc_node = nengo_dl.TensorNode(
        FullyConnectedNode(),
        size_in=np.prod(train_shape[1:]),
        size_out=10)

    # connect our input node to our fully-connected network node
    nengo.Connection(input_node, fc_node, synapse=None)

    # add some probes to collect data
    input_p = nengo.Probe(input_node)
    fc_p = nengo.Probe(fc_node)

In [8]:
minibatch_size = 10
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)

Build finished in 0:00:00
Optimization finished in 0:00:00
Construction finished in 0:00:00
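Once the simulator is built, the test images have to be fed in through `input_node` and the class probabilities read back from `fc_p`. As we understand the NengoDL 2.x conventions, input data for a Node is an array of shape `(minibatch_size, n_steps, node.size_out)` (passed, for example, via the `input_feeds` argument of `sim.run_steps`), and `sim.data[fc_p]` has shape `(minibatch_size, n_steps, 10)`. The NumPy helpers below, `make_input_feed` and `accuracy_from_probe`, are hypothetical names for a sketch of that data preparation and evaluation; check the shapes against the NengoDL version you are using.

```python
import numpy as np

# Hypothetical helpers (not part of the NengoDL API) for preparing input data
# and scoring probe output, following the assumed
# (minibatch_size, n_steps, dimensions) layout.

def make_input_feed(images, minibatch_size, n_steps=1):
    """Flatten the first `minibatch_size` images and repeat them over time.

    Returns an array of shape (minibatch_size, n_steps, 28 * 28).
    """
    flat = images[:minibatch_size].reshape(minibatch_size, 1, -1)
    return np.tile(flat, (1, n_steps, 1))


def accuracy_from_probe(probe_data, labels):
    """Accuracy of the argmax class at the final timestep of each example."""
    predictions = np.argmax(probe_data[:, -1], axis=-1)
    return float(np.mean(predictions == labels))


# usage with synthetic stand-in data
rng = np.random.default_rng(0)
fake_images = rng.random((20, 28, 28))
feed = make_input_feed(fake_images, minibatch_size=10, n_steps=3)
print(feed.shape)  # (10, 3, 784)

# a fake probe output whose final-step argmax for example i is class i
fake_probe = np.zeros((10, 3, 10))
fake_probe[np.arange(10), -1, np.arange(10)] = 1.0
print(accuracy_from_probe(fake_probe, np.arange(10)))  # 1.0
```

With the real network, one would expect to call something like `sim.run_steps(1, input_feeds={input_node: make_input_feed(test_images, minibatch_size)})` and then score `sim.data[fc_p]` against `test_labels[:minibatch_size]`.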