# imports

In [1]:
import tensorflow as tf
# NOTE(review): presumably 0 selects the inference (non-training) learning
# phase — confirm against the Keras backend docs. This call runs before the
# model is loaded further down.
tf.keras.backend.set_learning_phase(0)

import tensorflow.keras.backend as K
import tensorflow.contrib.tensorrt as trt

## save utils

In [2]:
def save_graph_to_pb(graph, destination_path):
    """Write a serialized graph to a binary file.

    Args:
        graph: Object exposing ``SerializeToString()`` that returns bytes.
        destination_path: Path of the file to create or overwrite.

    The graph is serialized *before* the destination is opened, so a failure
    during serialization can neither truncate an existing file nor leave an
    empty one behind.
    """
    serialized_graph = graph.SerializeToString()
    with open(destination_path, 'wb') as pb_file:
        pb_file.write(serialized_graph)

## load model

In [3]:
# Load the saved Keras model; its `outputs` provide the output node names
# used by the freezing and TensorRT cells below.
model = tf.keras.models.load_model('model/model.h5')

## freeze session

In [4]:
def freeze_session_into_graph(session, output_names=None, keep_var_names=None, clear_devices=True):
    """Freeze a session's graph by converting its variables to constants.

    Args:
        session: Session whose ``graph`` is exported; it is also handed to
            ``tf.graph_util.convert_variables_to_constants``.
        output_names: Optional iterable of output node names to keep. The
            caller's object is never modified; every global variable's op
            name is appended to a private copy.
        keep_var_names: Optional iterable of variable op names that must
            NOT be frozen.
        clear_devices: When true, blank the ``device`` field of every node in
            the exported graph def before conversion.

    Returns:
        The value returned by ``tf.graph_util.convert_variables_to_constants``.
    """
    graph = session.graph
    with graph.as_default():
        # Query the variable collection once; both name lists derive from it.
        variable_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(variable_names).difference(keep_var_names or []))

        # Build a new list instead of using `+=`: in-place extension would
        # mutate the list object the caller passed in (and fail on tuples).
        output_names = list(output_names or []) + variable_names
        input_graph_def = graph.as_graph_def()

        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)

        return frozen_graph

In [5]:
# Freeze the graph of the current Keras backend session, keeping the op
# names of the model's output tensors as output nodes.
frozen_graph = freeze_session_into_graph(session=K.get_session(),
                                        output_names=[out.op.name for out in model.outputs])

INFO:tensorflow:Froze 975 variables.
INFO:tensorflow:Converted 975 variables to const ops.


In [6]:
# Persist the frozen graph next to the original model file.
save_graph_to_pb(frozen_graph, 'model/frozen.pb')

## TensorRT

In [7]:
def optimize_graph_with_tensorrt(frozen_graph, output_layers, precision_mode='FP32',
                                 max_batch_size=1, workspace_bytes=2*10**9):
    """Run ``trt.create_inference_graph`` on a frozen graph and return its result.

    Args:
        frozen_graph: Forwarded as ``input_graph_def``.
        output_layers: Forwarded as ``outputs``.
        precision_mode: Forwarded as ``precision_mode``.
        max_batch_size: Forwarded as ``max_batch_size``.
        workspace_bytes: Forwarded as ``max_workspace_size_bytes``.

    Returns:
        Whatever ``trt.create_inference_graph`` returns.
    """
    conversion_options = dict(
        input_graph_def=frozen_graph,
        outputs=output_layers,
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=workspace_bytes,
        precision_mode=precision_mode,
    )
    return trt.create_inference_graph(**conversion_options)

In [8]:
# Optimize the frozen graph for the model's output nodes with a maximum
# batch size of 2.
# NOTE(review): the log line "Running against TensorRT version 0.0.0" below
# looks like TensorRT may not be linked into this build — confirm that the
# optimized graph really contains TensorRT engine nodes.
inference_graph = optimize_graph_with_tensorrt(frozen_graph, 
                                               [out.op.name for out in model.outputs], 
                                               max_batch_size=2)

INFO:tensorflow:Running against TensorRT version 0.0.0


In [9]:
# Compare graph sizes before and after optimization. `len()` on the node
# collection replaces building a throwaway list of ones just to sum it.
print('The input graph has {} nodes'.format(len(frozen_graph.node)))
print('The output graph has {} nodes'.format(len(inference_graph.node)))

The input graph has 1564 nodes
The output graph has 602 nodes


In [10]:
# Persist the optimized graph; the "check" section reloads this file.
save_graph_to_pb(inference_graph, 'model/inference.pb')

# check

In [11]:
def load_graph_from_pb(graph_path):
    """Load a serialized graph definition from disk into a new session.

    Args:
        graph_path: Path of the binary file to parse with
            ``tf.GraphDef.ParseFromString``.

    Returns:
        A new ``tf.Session`` whose graph contains the imported nodes.
        ``tf.import_graph_def`` is called without a ``name`` argument; the
        notebook later looks tensors up under the ``import/`` prefix.

    Side effects:
        Resets the process-wide default graph (``tf.reset_default_graph``).
    """
    # Parse the file first so an unreadable or corrupt file fails before the
    # default graph is reset or a session is opened.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    tf.reset_default_graph()
    session = tf.Session()

    # `as_default()` only takes effect when entered as a context manager;
    # calling it bare does nothing.
    with session.graph.as_default():
        tf.import_graph_def(graph_def)
    return session

# Reload the optimized graph that was saved to disk above.
sess = load_graph_from_pb('model/inference.pb')

In [12]:
# Write the reloaded graph through a summary FileWriter into 'model/'
# (presumably for inspection in TensorBoard — confirm), then flush and close.
writer = tf.summary.FileWriter('model/', sess.graph)
writer.flush()
writer.close()

In [13]:
# Look up the tensor named "input_1:0" under the "import/" prefix of the
# reloaded graph and display it.
sess.graph.get_tensor_by_name("import/input_1:0")

<tf.Tensor 'import/input_1:0' shape=(?, 128, 128, 1) dtype=float32>

In [14]:
# Look up the "0_conv_1x1_parts/BiasAdd:0" tensor under the "import/" prefix
# of the reloaded graph and display it.
sess.graph.get_tensor_by_name("import/0_conv_1x1_parts/BiasAdd:0")

<tf.Tensor 'import/0_conv_1x1_parts/BiasAdd:0' shape=(?, 32, 32, 7) dtype=float32>