In [46]:
import copy
import warnings

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from keras import backend as K
from keras.models import load_model
# NOTE(review): `custom_objects` is not referenced in the visible cells; presumably it is
# imported for its side effect of registering the EfficientNet custom layers with Keras so
# `load_model` can deserialize them — verify before removing.
from keras_efficientnets import custom_objects
# NOTE(review): `initializers` is not referenced in the visible cells.
from tensorflow.keras import initializers
from tensorflow.python.framework import graph_io

In [47]:
# Start from a clean graph/session. The model is loaded with standalone `keras`
# (keras.models.load_model), so reset keras' own backend state rather than tf.keras'.
K.clear_session()

In [48]:
# Output directory for the frozen graph, and the trained Keras model file to freeze.
save_pb_dir, model_fname = '.', 'Efficientnet_model_weights_NEW4_trial4.h5'
def freeze_graph(graph, session, output, save_pb_dir='.', save_pb_name='frozen_model.pb',
                 save_pb_as_text=False, function_names=('swish_f32',)):
    """Freeze a graph's variables into constants, write it to disk, and return it.

    Args:
        graph: Graph holding the model to freeze.
        session: Session whose variable values are baked into the constants.
        output: Output node names, passed to ``convert_variables_to_constants``.
        save_pb_dir: Directory the frozen graph is written to.
        save_pb_name: File name of the frozen graph.
        save_pb_as_text: Write a text protobuf if True, binary otherwise.
        function_names: Names of function-library entries to copy from the original
            GraphDef into the stripped one, or None to copy every entry. Defaults to
            ``('swish_f32',)``, the only entry this function previously copied.

    Returns:
        The frozen GraphDef (the same object that was written to disk).
    """
    with graph.as_default():
        graph_def = graph.as_graph_def()
        graphdef_inf = tf.graph_util.remove_training_nodes(graph_def)

        # The GraphDef returned by remove_training_nodes may lack the original function
        # library (NOTE(review): presumably why the swish function had to be re-added —
        # nodes calling a missing function cannot be imported). Copy the requested
        # entries back, skipping names already present so none is duplicated.
        present = {fn.signature.name for fn in graphdef_inf.library.function}
        for function_def in graph_def.library.function:
            name = function_def.signature.name
            if function_names is not None and name not in function_names:
                continue
            if name in present:
                continue
            graphdef_inf.library.function.extend([copy.deepcopy(function_def)])
            present.add(name)

        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, save_pb_dir, save_pb_name, as_text=save_pb_as_text)
        return graphdef_frozen

In [49]:
# Inference mode must be set *before* the model is built, so layers whose behaviour
# depends on the learning phase (e.g. BatchNormalization, Dropout) are built for inference.
# The model is loaded with standalone `keras`, so use keras' backend (K) for the learning
# phase and session rather than tf.keras', whose state may be separate.
K.set_learning_phase(0)

model = load_model(model_fname)

session = K.get_session()

input_names = [t.op.name for t in model.inputs]
output_names = [t.op.name for t in model.outputs]

# Input/output node names are needed later to feed and fetch the frozen graph.
print(input_names, output_names)

frozen_graph = freeze_graph(session.graph, session, output_names, save_pb_dir=save_pb_dir)

['model_2_input'] ['dense_1/Softmax']
INFO:tensorflow:Froze 311 variables.
INFO:tensorflow:Converted 311 variables to const ops.


In [50]:
# TF-TRT conversion settings, named so they are easy to find and tune.
TRT_MAX_BATCH_SIZE = 1
TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 25  # 32 MiB
TRT_PRECISION_MODE = 'FP16'
TRT_MINIMUM_SEGMENT_SIZE = 50  # smaller convertible subgraphs are left as TensorFlow ops

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=TRT_MAX_BATCH_SIZE,
    max_workspace_size_bytes=TRT_MAX_WORKSPACE_SIZE_BYTES,
    precision_mode=TRT_PRECISION_MODE,
    minimum_segment_size=TRT_MINIMUM_SEGMENT_SIZE,
)

# create_inference_graph leaves unconvertible parts as plain TensorFlow ops, so a
# successful call does not prove TensorRT was used. Count the engine nodes to make
# that visible (a reported TensorRT version of 0.0.0 likely means this TF build is not
# linked against TensorRT — verify).
trt_engine_count = sum(1 for node in trt_graph.node if node.op == 'TRTEngineOp')
if trt_engine_count == 0:
    warnings.warn(
        "No TRTEngineOp nodes in the converted graph: TensorRT was not used. "
        "Check that this TensorFlow build is linked against TensorRT and that "
        "minimum_segment_size is not too large."
    )
trt_engine_count

INFO:tensorflow:Linked TensorRT version: (0, 0, 0)
INFO:tensorflow:Loaded TensorRT version: (0, 0, 0)
INFO:tensorflow:Running against TensorRT version 0.0.0


In [51]:
# Serialise the TensorRT-converted GraphDef as a binary protobuf under ./model/.
graph_io.write_graph(
    trt_graph,
    logdir="./model/",
    name="trt_graph.pb",
    as_text=False,
)

'./model/trt_graph.pb'