In [4]:
from __future__ import print_function
import tensorflow as tf

In [2]:
# 1. Create a graph that adds two constants
graph = tf.Graph()
with graph.as_default():
    # Every op built inside this block is registered on `graph`,
    # not on the global default graph.
    a, b = tf.constant(1), tf.constant(2)
    c = a + b  # symbolic sum; only evaluated when run in a session

In [6]:
# Evaluate `c` in a session bound to `graph`; the session is closed on exit.
with tf.Session(graph=graph) as sess:
    print(sess.run(fetches=c))  # 3, nice

3


In [3]:
# 2. Serialize the graph with protocol buffers and write the binary bytes to disk.
# `as_graph_def()` returns a GraphDef proto describing the nodes of `graph`.
graph_def = graph.as_graph_def()
with open("tf_import_graph_def.pb", mode="wb") as wf:
    wf.write(graph_def.SerializeToString())

In [7]:
# 3. Load the serialized GraphDef back and import it into a brand-new graph.
with open("tf_import_graph_def.pb", "rb") as rf:
    # FromString builds a fresh GraphDef and parses the bytes into it.
    new_graph_def = tf.GraphDef.FromString(rf.read())

new_graph = tf.Graph()
with new_graph.as_default():
    # name="" disables the default "import/" prefix, so imported nodes keep
    # the names they had in `graph` and can be looked up by those names.
    tf.import_graph_def(new_graph_def, name="")

In [8]:
# `c` still belongs to the original `graph`; only its name string is used to find
# the matching tensor in `new_graph` (names were preserved by name="" on import).
new_c = new_graph.get_tensor_by_name(name=c.name)

In [9]:
# Run the imported tensor in a session bound to the re-imported graph.
with tf.Session(graph=new_graph) as sess:
    print(sess.run(fetches=new_c))  # 3

3


### Don't use `tf.import_graph_def` on a `Graph` that contains a `Variable`

In [10]:
# A simple test: a graph with one trainable scalar variable and a quadratic loss.
variable_graph = tf.Graph()
with variable_graph.as_default():
    # NOTE: this rebinds the notebook-level name `a` (previously the constant
    # from the first graph); later cells rely on `a` being this Variable.
    a = tf.Variable(initial_value=3.0, name="a", dtype=tf.float32)
    loss = (a - 1.0) ** 2  # smallest when a == 1.0

In [11]:
# `tf.trainable_variables()` is a shortcut for reading this collection from the
# default graph; here the graph is named explicitly instead.
print(variable_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))  # Ok, we have one variable to optimize on

[<tf.Variable 'a:0' shape=() dtype=float32_ref>]


In [12]:
# Serialize the variable graph to a binary GraphDef file.
#
# A GraphDef stores the graph's nodes only. Graph collections such as
# TRAINABLE_VARIABLES are not part of it (they belong to a MetaGraphDef), so
# they will not survive the round trip.
#
# Use a dedicated name instead of overwriting `graph_def` from the first
# example (hidden state). Also build the bytes *before* opening the file, so a
# serialization error does not leave an empty, truncated .pb behind.
variable_graph_def = variable_graph.as_graph_def()
serialized_variable_graph = variable_graph_def.SerializeToString()
with open("tf_import_graph_def_variable.pb", "wb") as wf:
    wf.write(serialized_variable_graph)

In [13]:
# Load the serialized variable graph back and import it into a fresh graph.
with open("tf_import_graph_def_variable.pb", "rb") as rf:
    new_variable_graph_def = tf.GraphDef.FromString(rf.read())

# NOTE: this rebinds `new_graph` (previously the imported constant graph);
# the cells below refer to this re-imported variable graph.
new_graph = tf.Graph()
with new_graph.as_default():
    # name="" keeps the original node names, e.g. the variable op stays "a".
    tf.import_graph_def(new_variable_graph_def, name="")

Everything seems all right, but...

In [14]:
# Same collection lookup as `tf.trainable_variables()`, but on the imported graph.
print(new_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))  # `a` is gone!

[]


In [17]:
# The variable's op itself did survive the import: take the first output of
# the op named "a" (the tensor "a:0") and show its name and op type.
new_a = new_graph.get_operation_by_name("a").outputs[0]
print(new_a.name, new_a.op.type)

a:0 VariableV2


In [20]:
# Compare the Python type of the original variable with the imported object.
print(*map(type, (a, new_a)))

<class 'tensorflow.python.ops.variables.Variable'> <class 'tensorflow.python.framework.ops.Tensor'>


So `new_graph` is in fact a graph that can't be trained.

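`tf.import_graph_def` fails to load the variable back. This may look like a bug, but it is expected behaviour. A `GraphDef` stores only the graph's nodes, so the `VariableV2` op survives as a plain `Tensor`. The Python-side metadata, such as the `trainable_variables` collection and the `tf.Variable` wrapper, is not serialized and is lost.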

So what should we do?

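For now, I suggest using `tf.train.Saver` to save a trainable graph. It writes a `MetaGraphDef` (the `.meta` file, which keeps collections such as the trainable variables) together with a checkpoint of the variable values. `tf.train.import_meta_graph` rebuilds the graph with its collections, and `saver.restore` reloads the values.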

Only use plain `GraphDef` serialization/deserialization for constant graphs, i.e. frozen graphs in which the variables have been converted to constants.