# Optimize for inference

This notebook optimizes a frozen graph for inference using TensorFlow's Graph Transform Tool (`TransformGraph`).

In [1]:
# Central table of model file locations; every path lives under ./models.
CONFIG = dict(
    graphdef_file="./models/keras_graphdef.pb",
    frozen_model_file="./models/keras_frozen_model.pb",
    snapshot_dir="./models/snapshot",
    opt_model_file="./models/keras_opt_model.pb",
)

In [2]:
import tensorflow as tf

def load_graph_for_transform(frozen_graph_filename):
    """Read a serialized GraphDef protobuf from disk.

    Args:
        frozen_graph_filename: Path of the binary ``.pb`` file to read.

    Returns:
        A ``tf.GraphDef`` populated from the file's contents.
    """
    # Pull the raw protobuf bytes off disk first ...
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized = pb_file.read()

    # ... then deserialize them into a fresh GraphDef message.
    parsed_def = tf.GraphDef()
    parsed_def.ParseFromString(serialized)
    return parsed_def

def load_graph(frozen_graph_filename):
    """Load a frozen GraphDef file into a brand-new ``tf.Graph``.

    Args:
        frozen_graph_filename: Path of the binary ``.pb`` file to import.

    Returns:
        The newly created ``tf.Graph`` holding the imported nodes.
    """
    graph_def = load_graph_for_transform(frozen_graph_filename)

    imported = tf.Graph()
    with imported.as_default():
        # `name` would be prepended to every imported op/node; the target
        # graph is new, so no prefix is needed and an empty one is passed.
        tf.import_graph_def(graph_def, name='')
    return imported


  from ._conv import register_converters as _register_converters


In [3]:
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.python.framework       import graph_io

# Load the frozen model as a GraphDef protobuf; this is what TransformGraph
# receives as its first argument. The path comes from CONFIG so that the
# configuration cell is the single place where file locations are defined.
graph = load_graph_for_transform(CONFIG["frozen_model_file"])

# Names of the graph's input and output nodes, passed to TransformGraph.
input_names = ['img_i_1', 'img_f_1']
output_names = ['class_1/Sigmoid']

# Transform specifications handed to the Graph Transform Tool.
transforms = [
    'strip_unused_nodes(type=float, shape="1,224,544,3")',
    'remove_nodes(op=Identity, op=CheckNumerics)',
    'fold_constants(ignore_errors=true)',
    'fold_batch_norms',
    'fold_old_batch_norms',
]

G_opt = TransformGraph(graph, input_names, output_names, transforms)

# Serialize the optimized GraphDef and write it to the configured location.
with tf.gfile.GFile(CONFIG["opt_model_file"], "wb") as f:
    f.write(G_opt.SerializeToString())


#### Compare the number of operations before and after optimization

In [4]:
# Operation count of the frozen graph (before optimization). The path is
# taken from CONFIG instead of being repeated as a literal.
graph = load_graph(CONFIG["frozen_model_file"])

print(len(graph.get_operations()))

# To inspect individual operations, print `op.name` for each `op` in
# `graph.get_operations()`.


501


In [5]:
# Operation count of the optimized graph (after the transforms). The path
# is taken from CONFIG instead of being repeated as a literal.
graph = load_graph(CONFIG["opt_model_file"])

print(len(graph.get_operations()))

# To inspect individual operations, print `op.name` for each `op` in
# `graph.get_operations()`.


293
