In [1]:
import tensorflow as tf
import os
from google.protobuf import text_format

# Write a GraphDef to a protobuf file

In [2]:
# Directory that receives the serialized GraphDef files written below.
model_dir = 'graph'

# Create the directory if needed (like `mkdir -p`). Attempting the creation and
# tolerating "already exists" avoids the check-then-create race of an
# os.path.exists() guard. It still fails loudly right here if `model_dir` exists
# but is not a directory (e.g. a stray file named 'graph'), instead of failing
# later when the graph files are written into it.
try:
    os.makedirs(model_dir)
except OSError:
    if not os.path.isdir(model_dir):
        raise

In [3]:
# Build a tiny graph computing  output = input + add_value.
# The explicit `name=` values become the node names that later cells look up
# after the graph has been serialized and re-imported.
x = tf.placeholder(dtype=tf.float32, shape=(), name='input')  # scalar fed at run time
a = tf.Variable(dtype=tf.float32, shape=(), initial_value=1.0, name='add_value')  # scalar variable starting at 1.0
y = tf.add(x, a, name='output')

# Op that initializes the global variables created above (here only `add_value`).
init = tf.global_variables_initializer()

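* Get the GraphDef either from the active session (`sess.graph_def`) or from the default graph (`tf.get_default_graph().as_graph_def()`). Inside `with tf.Session() as sess:`, both describe the same graph, because the session was created on the default graph.
* Use `tf.io.write_graph()` to write the graph proto to a file. `as_text=True` produces a human-readable text file (`.pbtxt`); `as_text=False` produces a binary file (`.pb`).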

In [4]:
with tf.Session() as sess:
    sess.run(init)
    # Sanity check before saving: input 2.0 + add_value 1.0 should print 3.0.
    print(sess.run(y, feed_dict={x: 2.0}))

    # Two interchangeable sources for the GraphDef of the graph this session runs:
    # the session's `graph_def` property and the default graph's as_graph_def().
    # Positional arguments are write_graph(graph_or_graph_def, logdir, name, ...).
    tf.io.write_graph(sess.graph_def, model_dir, 'graph_def.pbtxt',
                      as_text=True)   # human-readable text format
    tf.io.write_graph(tf.get_default_graph().as_graph_def(), model_dir, 'graph_def.pb',
                      as_text=False)  # compact binary format

3.0


# Load a GraphDef from file

## Text format
Read the file and use `google.protobuf.text_format.Merge()` to merge its text representation into an empty `tf.GraphDef`.

In [5]:
# Start from an empty default graph, so the nodes imported below are not mixed
# with the ones built earlier in this notebook.
tf.reset_default_graph() # reset the graph

In [6]:
# A .pbtxt file is text, so open it in text mode. text_format.Merge() then
# receives a str, its documented input type. (The binary .pb file is read with
# 'rb' in the binary-format section.)
with open(os.path.join(model_dir, 'graph_def.pbtxt'), 'r') as f:
    graph_def = tf.GraphDef()  # create an empty GraphDef to merge into
    # Merge the protocol buffer message's text representation into graph_def.
    text_format.Merge(f.read(), graph_def)

In [7]:
# graph_def.node is a repeated field of NodeDef messages. Only `name` is printed
# here; each NodeDef also exposes `op`, `input`, `device` and `attr`.
for node in graph_def.node:
    print(node.name)

input
add_value/initial_value
add_value
add_value/Assign
add_value/read
output
init


In [8]:
# Import the loaded GraphDef into the current default graph. `name` is the
# prefix given to every imported node. 'import' is also TensorFlow's default,
# and later cells refer to 'import/input', 'import/output' and 'import/init'.
# NOTE: not idempotent. Re-running this cell without resetting the graph adds
# a second copy under a uniquified prefix (e.g. 'import_1').
tf.import_graph_def(graph_def, name='import')

In [9]:
# List the default graph's nodes. Every imported node now carries the import prefix.
for node_def in tf.get_default_graph().as_graph_def().node:
    print(node_def.name)

import/input
import/add_value/initial_value
import/add_value
import/add_value/Assign
import/add_value/read
import/output
import/init


In [10]:
with tf.Session() as sess:
    # Fetches and feeds can be given as strings: a bare op name runs that op,
    # while '<op name>:0' refers to the op's first output tensor.
    sess.run('import/init')  # initialize the imported variable
    print(sess.run(fetches='import/output:0',
                   feed_dict={'import/input:0': 20.0}))

21.0


## Binary format
Read the file's bytes and use `tf.GraphDef.ParseFromString()` to deserialize them into an empty `tf.GraphDef`.

In [11]:
# Start again from an empty default graph. This discards the nodes imported from
# the text-format file, so the binary-format import below gets the same 'import/' names.
tf.reset_default_graph() # reset the graph

In [12]:
graph_def = tf.GraphDef()  # empty GraphDef to deserialize into
with open(os.path.join(model_dir, 'graph_def.pb'), 'rb') as f:
    # The .pb file holds the serialized (binary) protocol buffer; parse its raw bytes.
    graph_def.ParseFromString(f.read())

In [13]:
# Both files were written from the same graph, so this lists the same nodes
# as the text-format load.
for node_def in graph_def.node:
    print(node_def.name)

input
add_value/initial_value
add_value
add_value/Assign
add_value/read
output
init


In [14]:
# Import the loaded GraphDef into the current default graph. `name` is the
# prefix given to every imported node. 'import' is also TensorFlow's default,
# and the lookups below rely on it.
# NOTE: not idempotent. Re-running this cell without resetting the graph adds
# a second copy under a uniquified prefix (e.g. 'import_1').
tf.import_graph_def(graph_def, name='import')

In [15]:
# List the default graph's nodes. Every imported node now carries the import prefix.
for node_def in tf.get_default_graph().as_graph_def().node:
    print(node_def.name)

import/input
import/add_value/initial_value
import/add_value
import/add_value/Assign
import/add_value/read
import/output
import/init


In [16]:
# Resolve graph handles by name. A tensor name is '<op name>:<output index>',
# so 'import/input:0' is the first output of the op 'import/input'.
inp = tf.get_default_graph().get_tensor_by_name('import/input:0')    # placeholder tensor to feed
out = tf.get_default_graph().get_tensor_by_name('import/output:0')   # result tensor to fetch

# The imported graph's own initializer op. A GraphDef presumably carries no
# variable collections, so tf.global_variables_initializer() would likely not
# cover the imported variable; verify if this is changed.
init_op = tf.get_default_graph().get_operation_by_name('import/init')

In [17]:
with tf.Session() as sess:
    sess.run(init_op)  # initialize the imported variable before reading it
    # Feed and fetch through the tensor objects looked up above, not their names.
    print(sess.run(fetches=out, feed_dict={inp: 100.0}))

101.0
