
Preparing models for ofxMSATensorFlow

Memo Akten edited this page Jun 1, 2017 · 1 revision

This page describes how to prepare, train and export models for use in openframeworks / ofxMSATensorFlow.

Short Version

  • The trained model needs to be exported as a single, self-contained, pruned and frozen graph definition protobuf (.pb).
  • Any operators used for input (feeding) and output (fetching) need to have names that you know.
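Both points can be checked before handing a model to C++. The following is a minimal sketch of a hypothetical helper (check_exportable is not part of tensorflow or ofxMSATensorFlow). It works on plain (name, op type) pairs; with tensorflow you could build them from a GraphDef as [(n.name, n.op) for n in graph_def.node]. The set of op types treated as un-frozen variables is an assumption and may not be exhaustive.

```python
# Hypothetical pre-export check, not part of tensorflow or ofxMSATensorFlow.
# `nodes` is a list of (name, op_type) pairs; with tensorflow these could come
# from a GraphDef via [(n.name, n.op) for n in graph_def.node].

# Assumption: op types that indicate an un-frozen variable (may not be exhaustive).
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def check_exportable(nodes, required_names):
    """Return a list of problems found; an empty list means both checks passed."""
    problems = []
    names = {name for name, _ in nodes}
    # 1. A frozen graph should hold constants, not variables.
    for name, op in nodes:
        if op in VARIABLE_OPS:
            problems.append(f"not frozen: {name} is a {op}")
    # 2. Every node we plan to feed or fetch must exist under a known name.
    for required in required_names:
        if required not in names:
            problems.append(f"missing node: {required}")
    return problems


frozen = [("data_in", "Placeholder"), ("w", "Const"), ("data_out", "Identity")]
unfrozen = [("data_in", "Placeholder"), ("w", "VariableV2"), ("MatMul", "MatMul")]
print(check_exportable(frozen, ["data_in", "data_out"]))    # []
print(check_exportable(unfrozen, ["data_in", "data_out"]))
# ['not frozen: w is a VariableV2', 'missing node: data_out']
```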

Long Version

Setting up and training a model in tensorflow can be summarised as follows:

  • Define a model architecture (aka graph definition aka graph_def). This consists of a bunch of nodes (aka operators aka ops), which include all kinds of mathematical operations, variables, etc.
  • Do a bunch of other things (e.g. define a cost function to be minimised, choose an optimisation algorithm, etc.).
  • Train (run the optimisation).

The act of training modifies the variables such that the cost function is minimised.
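As a toy illustration of that sentence, here is gradient descent written out in plain python rather than tensorflow. The "model" y = w*x + b, the mean squared error cost, the learning rate and the data are all illustrative choices; the point is only that training repeatedly nudges the variables (w and b) so that the cost goes down.

```python
# Toy illustration in plain python (no tensorflow): the "variables" w and b are
# adjusted by gradient descent so that the mean squared error cost decreases.


def cost(w, b, xs, ys):
    """Mean squared error of the model y = w*x + b."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def train(xs, ys, learning_rate=0.05, steps=2000):
    w, b = 0.0, 0.0  # initial variable values
    n = len(xs)
    for _ in range(steps):
        # Gradients of the MSE cost with respect to each variable.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        # The optimisation step: move each variable against its gradient.
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


xs = [0.0, 1.0, 2.0, 3.0]
ys = [2 * x + 1 for x in xs]   # data generated with w=2, b=1
w, b = train(xs, ys)
print(round(w, 3), round(b, 3), cost(w, b, xs, ys))  # w, b end up close to 2 and 1
```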

Currently the most popular environment for building and training models is the tensorflow python API. It is the most complete API, most of the tensorflow community uses it, and most examples out there are in python.

There are two main things to look out for when using your own trained models (i.e. graphs) in ofxMSATensorFlow (or, AFAIK, in any C++ application):

1. Graph Definition vs Variables

The default file format for saving the results of training a tensorflow model is a checkpoint (.ckpt) file. However, these files only contain values for variables (by name), and don't contain any architecture information, so loading them by themselves in C++ wouldn't be enough (AFAIK there is no C++ loader for ckpt files anyway).

In order to load and use checkpoint files, one must first construct the architecture exactly as it was when the checkpoint file was saved, and then load the checkpoint file. For this reason, most python tensorflow examples first build the architecture by code, and then load the checkpoint file. In fact they often use the same architecture building code for training and for inference, by simply including the same .py file. This is great for python users, but not for model distribution.
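To make the distinction concrete, the sketch below mimics a checkpoint with a JSON dictionary of variable values keyed by name. This is a toy stand-in, not the real .ckpt format, and the variable names are hypothetical. The values alone cannot compute anything: the architecture code (build_model) has to be run first, with matching variable names, before they can be restored into it.

```python
import json

# Toy stand-in for a checkpoint: variable values keyed by name, nothing else.
# (Not the real .ckpt format; the variable names are hypothetical.)
checkpoint_json = json.dumps({"dense/w": 2.0, "dense/b": 1.0})


def build_model():
    """Architecture code: creates the variables and the computation using them."""
    variables = {"dense/w": 0.0, "dense/b": 0.0}

    def predict(x):
        return variables["dense/w"] * x + variables["dense/b"]

    return variables, predict


def restore(variables, ckpt_json):
    """Copy saved values into existing variables, matching them by name."""
    saved = json.loads(ckpt_json)
    for name in variables:
        variables[name] = saved[name]  # KeyError if the names don't match


variables, predict = build_model()   # the architecture must exist first...
restore(variables, checkpoint_json)  # ...before the values can be loaded into it
print(predict(3.0))                  # 7.0
```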

The file format which contains the architecture information (also referred to as the graph definition) and can be loaded in C++ is protobuf (.pb); more recent versions of tensorflow also have metagraphs (.meta), which include this too. However, these files only contain the architecture and not the trained variable values!

The way around this is graph freezing: after training, the variables in a graph are replaced by constants holding the same values (bear in mind there might be millions of variables in a graph). It's also common to prune the graph, i.e. remove the parts of the graph which are only needed for training and not for inference (e.g. everything used for calculating the cost, managing gradients, etc.). Once pruned and frozen, a graph_def can be saved as a normal .pb with tf.train.write_graph, and it will contain the trained variable values (as constants) and no unnecessary nodes.

Tensorflow provides a utility to do this called

Alternatively, there are a few built-in tensorflow utility functions that do this as well, e.g. tf.graph_util.convert_variables_to_constants. I demonstrate this in my versions of pix2pix-tensorflow, char-rnn-tensorflow and write-rnn-tensorflow. In summary, with one line of code you can prune the graph and convert its variables to constants:

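            # output_node_names: list of node-name strings, e.g. ['data_out'] (hypothetical name)
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names)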
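The freezing and pruning described above can be sketched on a toy graph in plain python. This is not tensorflow's GraphDef format or API; the node layout (op type, input names, value), the op names and the node names are hypothetical simplifications. prune walks backwards from the requested output names, dropping anything (like the cost) that the outputs don't depend on, and freeze swaps each Variable node for a Const holding its trained value.

```python
# Toy graph: node name -> (op type, input node names, constant value or None).
# Not tensorflow's GraphDef; just enough structure to show pruning and freezing.
graph = {
    "data_in":  ("Placeholder", [], None),
    "w":        ("Variable", [], None),
    "b":        ("Variable", [], None),
    "mul":      ("Mul", ["data_in", "w"], None),
    "data_out": ("Add", ["mul", "b"], None),
    "target":   ("Placeholder", [], None),            # only needed for training
    "diff":     ("Sub", ["data_out", "target"], None),  # only needed for training
    "cost":     ("Square", ["diff"], None),            # only needed for training
}
trained_values = {"w": 2.0, "b": 1.0}  # values the variables had after training

OPS = {
    "Mul": lambda a, b: a * b,
    "Add": lambda a, b: a + b,
    "Sub": lambda a, b: a - b,
    "Square": lambda a: a * a,
}


def prune(graph, output_names):
    """Keep only the nodes that the given outputs actually depend on."""
    keep, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(graph[name][1])  # visit this node's inputs
    return {name: node for name, node in graph.items() if name in keep}


def freeze(graph, values):
    """Replace every Variable node with a Const node holding its trained value."""
    frozen = {}
    for name, (op, inputs, value) in graph.items():
        if op == "Variable":
            frozen[name] = ("Const", [], values[name])
        else:
            frozen[name] = (op, inputs, value)
    return frozen


def run(graph, fetch, feed):
    """Evaluate node `fetch`, with placeholder values supplied by name in `feed`."""
    op, inputs, value = graph[fetch]
    if op == "Placeholder":
        return feed[fetch]
    if op == "Const":
        return value
    return OPS[op](*[run(graph, name, feed) for name in inputs])


frozen = freeze(prune(graph, ["data_out"]), trained_values)
print(sorted(frozen))                             # ['b', 'data_in', 'data_out', 'mul', 'w']
print(run(frozen, "data_out", {"data_in": 3.0}))  # 7.0
```

Because the training-only nodes (target, diff, cost) are pruned away, the frozen toy graph only needs data_in to be fed.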

2. Referencing Nodes for Feed and Fetch

The basic idea of 'running a model' for inference in tensorflow is that you 'feed' various input nodes some data, and you request to 'fetch' the output of various other nodes in the graph. There isn't necessarily one input and one output. The architecture can be quite complicated, hence it's called a 'graph'. This is why we need to specifically indicate which bits of data we are feeding to which nodes, and which nodes we would like to fetch the output from.
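A minimal toy version of feed and fetch, again in plain python rather than tensorflow: session_run is a hypothetical helper loosely modelled on the idea behind sess.run(fetches, feed_dict). It takes a list of node names to fetch and a dictionary of fed values keyed by node name, and evaluates only what the requested outputs need. The graph here, with its two inputs and two outputs, is made up for illustration.

```python
# Toy "session run": feed values to input nodes by name, fetch outputs by name.
# The graph layout (name -> (function, input names)) is a hypothetical stand-in
# for a real graph_def; input nodes have no function and must be fed.
graph = {
    "image":   (None, []),  # first input node
    "scale":   (None, []),  # second input node
    "scaled":  (lambda img, s: [p * s for p in img], ["image", "scale"]),
    "mean":    (lambda xs: sum(xs) / len(xs), ["scaled"]),
    "maximum": (lambda xs: max(xs), ["scaled"]),
}


def session_run(graph, fetches, feed_dict):
    """Return the values of the `fetches` nodes, given `feed_dict` {name: value}."""
    cache = dict(feed_dict)  # fed nodes are known up front

    def evaluate(name):
        if name not in cache:
            fn, inputs = graph[name]
            if fn is None:
                raise ValueError(f"input node '{name}' must be fed")
            # Each node is computed once, even if several outputs share it.
            cache[name] = fn(*[evaluate(i) for i in inputs])
        return cache[name]

    return [evaluate(name) for name in fetches]


mean, maximum = session_run(graph, ["mean", "maximum"],
                            {"image": [1.0, 2.0, 3.0], "scale": 2.0})
print(mean, maximum)  # 4.0 6.0
```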

As mentioned previously, most python tensorflow examples currently build the architecture by code, and then save and load checkpoint files to restore variable values. When they need to reference nodes, they simply reference them directly via the python variables that hold them.

However, if we load a frozen graph, e.g. in C++, we don't have access to those python variables. But we can reference the nodes by name. This means that when building the model architecture, it helps to give special nodes (i.e. nodes that we'd like to write to and read from) memorable and easily identifiable names. This means not the names of the python variables that we store the nodes in, but the actual names given to the 'name' parameter of the node. And if a node can't be given a name (some functions don't take a name parameter), then we might need to add extra nodes to be able to write to or read from it.

E.g. this commit does exactly that.

In line 34, self.input_data = tf.placeholder(dtype=tf.float32, shape=[None, args.seq_length, 3], name='data_in'), the name by which this node is referenced is data_in, not input_data. The latter is the name of a python variable, which can only be used within this python program. The name that gets saved in the graph_def is data_in.
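The difference between a python variable and a node name can be shown with a toy graph builder. ToyGraph is hypothetical and only imitates, in simplified form, how tensorflow gives unnamed nodes default names made unique with a numeric suffix (Placeholder, then Placeholder_1, and so on). Auto-generated names depend on creation order, which makes them fragile to reference from C++, whereas an explicit name like data_in stays put.

```python
# Toy graph builder (hypothetical): a node gets the name you pass, or else an
# auto-generated one derived from its op type, made unique with a suffix.


class ToyGraph:
    def __init__(self):
        self.nodes = {}    # node name -> op type (this is what gets "exported")
        self._counts = {}  # op type -> number of auto-names handed out so far

    def add_node(self, op, name=None):
        if name is None:
            count = self._counts.get(op, 0)
            self._counts[op] = count + 1
            name = op if count == 0 else f"{op}_{count}"
        if name in self.nodes:
            raise ValueError(f"duplicate node name '{name}'")
        self.nodes[name] = op
        return name  # the python variable only holds this handle


g = ToyGraph()
input_data = g.add_node("Placeholder", name="data_in")  # memorable, explicit name
x = g.add_node("Placeholder")                            # auto-named "Placeholder"
y = g.add_node("Placeholder")                            # auto-named "Placeholder_1"
print(sorted(g.nodes))  # ['Placeholder', 'Placeholder_1', 'data_in']
```

Only the keys of g.nodes would survive export; the python variable names input_data, x and y would not.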

You will see quite a few tf.identity commands, e.g. line 52 and lines 119-125. These are there simply to give the nodes a name. I.e. they take a node and add another node after it (which doesn't do much, just an identity transform), but crucially, that new node can have a name, which will be saved in the frozen graph and can later be accessed in C++. (Note: tf.identity also merges multiple tensors together, but that's another detail.)
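The tf.identity trick boils down to adding a pass-through node whose only purpose is to carry a memorable name. In the toy sketch below (plain python, hypothetical graph layout and node names), Add_3 stands for an output that received an auto-generated name, and data_out is an identity node pointing at it.

```python
# Toy sketch of the identity trick (hypothetical graph layout and node names):
# a pass-through node whose only job is to carry a memorable name.
graph = {
    "data_in": ("Placeholder", []),
    "Add_3":   ("Add", ["data_in", "data_in"]),  # auto-named by some helper
}
# Give the result a stable name without changing its value.
graph["data_out"] = ("Identity", ["Add_3"])


def run(graph, name, feed):
    """Evaluate node `name`, feeding placeholders by name from `feed`."""
    op, inputs = graph[name]
    if op == "Placeholder":
        return feed[name]
    args = [run(graph, i, feed) for i in inputs]
    if op == "Add":
        return args[0] + args[1]
    if op == "Identity":
        return args[0]  # output equals input
    raise ValueError(f"unknown op {op}")


print(run(graph, "data_out", {"data_in": 5}))  # 10
```

In tensorflow itself the equivalent would be something like output = tf.identity(some_tensor, name='data_out'), where some_tensor is whichever tensor you want to expose; the C++ side can then fetch data_out without knowing the auto-generated name.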

Stacking can also help sometimes. E.g. in line 130, by stacking all of the Mixture Density Network parameters into a single op, I can reference them all by a single name.
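A toy sketch of the stacking idea in plain python: several equally-shaped parameter vectors are packed into one structure (one row each, analogous to stacking along a new first axis with tf.stack), so the consumer only has to fetch one named node and split it by row index. The MDN parameter names, their order and the numbers are hypothetical; what matters is that producer and consumer agree on the order.

```python
# Toy sketch: pack equally-shaped outputs into one "tensor" so a consumer
# (e.g. C++) fetches a single named node and splits it by row index.
# The parameter names and their order are hypothetical; both sides must agree.
PARAM_ORDER = ["pi", "mu1", "mu2", "sigma1", "sigma2", "corr"]


def stack(params):
    """Producer side: one row per parameter, like stacking along a new axis 0."""
    lengths = {len(params[name]) for name in PARAM_ORDER}
    if len(lengths) != 1:
        raise ValueError("stacking needs all parameters to have the same shape")
    return [list(params[name]) for name in PARAM_ORDER]


def unstack(stacked):
    """Consumer side: recover each parameter from its agreed row index."""
    return {name: stacked[i] for i, name in enumerate(PARAM_ORDER)}


# Two mixture components per parameter (made-up numbers).
params = {"pi": [0.3, 0.7], "mu1": [0.0, 1.0], "mu2": [2.0, 3.0],
          "sigma1": [0.5, 0.5], "sigma2": [1.0, 1.5], "corr": [0.1, -0.2]}
mdn_out = stack(params)               # the single node fetched by name
print(len(mdn_out), len(mdn_out[0]))  # 6 2
print(unstack(mdn_out)["mu2"])        # [2.0, 3.0]
```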
