# Freeze a Keras model

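This notebook freezes a Keras model that was saved without its optimizer, so no custom objects need to be defined for any specific loss function or metrics. The result is a TensorFlow graph with the trained weights converted to constants, written to `./models/keras_frozen_model.pb`.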

In [1]:
''' Import Keras Modules '''
from keras.models import Sequential,Model, load_model,model_from_config
from keras import backend as K


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
''' Import Tensorflow Modules '''
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.training import saver as saver_lib

Instructions for updating:
Use the retry module or similar alternatives.


In [3]:
# Output locations used by freeze_model (relative to the working directory).
# NOTE: "snapshot_dir" is passed to Saver.save as a checkpoint path *prefix*
# (the saved checkpoint becomes "./models/snapshot-<global_step>"), not a directory.
CONFIG = dict(
    graphdef_file="./models/keras_graphdef.pb",         # inference GraphDef written by write_graph
    frozen_model_file="./models/keras_frozen_model.pb",  # final frozen graph
    snapshot_dir="./models/snapshot",                    # TF checkpoint prefix
)

In [4]:
def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it does not already exist.

    The TF checkpoint saver refuses to write into a missing directory, so the
    output locations must exist before anything is saved.
    """
    import os

    parent = os.path.dirname(path)
    if parent:  # a bare file name has no directory component to create
        os.makedirs(parent, exist_ok=True)


def freeze_model(model_file, config=None):
    """Load a Keras model file and export it as a frozen TensorFlow graph.

    Steps: load the model, re-build it from its config with the Keras learning
    phase fixed to 0 (test mode), copy the weights over, save a TF checkpoint,
    write the inference graph definition (text format), then run
    ``freeze_graph`` to turn the checkpointed variables into constants.

    Args:
        model_file: path to the Keras model file passed to ``load_model``.
            The model is expected to have been saved without an optimizer, so
            no custom loss/metric objects are needed to load it.
        config: optional mapping with keys ``"graphdef_file"``,
            ``"frozen_model_file"`` and ``"snapshot_dir"`` (the latter is used
            as the checkpoint path prefix). Defaults to the notebook's
            ``CONFIG``.

    Side effects:
        Creates a new ``tf.Session``, installs it as the Keras session and sets
        the Keras learning phase to 0. The session is left open because Keras
        keeps using it. Writes the checkpoint, the graph definition and the
        frozen graph to disk (creating parent directories as needed), and
        prints the model summary plus the input/output tensor names.
    """
    if config is None:
        config = CONFIG

    for key in ("graphdef_file", "frozen_model_file", "snapshot_dir"):
        _ensure_parent_dir(config[key])

    # open up a TensorFlow session and tell Keras to use it
    sess = tf.Session()
    K.set_session(sess)
    K.set_learning_phase(0)  # all new operations will be in test mode from now on

    # load the trained model and capture what is needed to re-build it
    trained_model = load_model(model_file)
    trained_model.summary()

    model_config = trained_model.get_config()
    weights = trained_model.get_weights()

    # Re-build the model so the learning phase is now hard-coded to 0.
    # Dispatch on the loaded model's type rather than try/bare-except: the bare
    # except also hid unrelated errors, and a half-built Sequential attempt
    # could leave stray layers/variables in the graph that would then be
    # initialized and checkpointed along with the real ones.
    if isinstance(trained_model, Sequential):
        model = Sequential.from_config(model_config)
    else:
        model = Model.from_config(model_config)

    sess.run(tf.global_variables_initializer())
    model.set_weights(weights)

    # Save all graph variables to a TF checkpoint (V2 format); the returned
    # path carries the "-<global_step>" suffix.
    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
    checkpoint_path = saver.save(sess, config['snapshot_dir'], global_step=0,
                                 latest_filename='checkpoint_state')

    # remove nodes not needed for inference from the graph def
    inference_graph = tf.graph_util.remove_training_nodes(sess.graph.as_graph_def())

    # Write the graph definition to a file. You can view it to see the network
    # structure and the names of the input/output tensors.
    graph_io.write_graph(inference_graph, '.', config['graphdef_file'])

    # model.inputs / model.outputs are always lists. model.input is a bare
    # tensor for single-input models (not iterable in graph mode), and
    # model.output is a list (no .name) for multi-output models.
    print("Input names:")
    for inp in model.inputs:
        print(inp.name)
    print("Output name:")
    for out in model.outputs:
        print(out.name)

    # freeze_graph expects op names (tensor name without ":0"), comma-separated
    out_names = ",".join(out.name.split(':')[0] for out in model.outputs)

    # freeze the inference graph and save it for later
    freeze_graph.freeze_graph(
        input_graph=config['graphdef_file'],
        input_saver='',
        input_binary=False,  # the graph def above was written in text format
        input_checkpoint=checkpoint_path,
        output_node_names=out_names,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=config['frozen_model_file'],
        clear_devices=False,
        initializer_nodes="",
    )

In [5]:
# Freeze the model that was saved without an optimizer; outputs go to CONFIG's paths.
freeze_model(model_file='./model-noop.h5')



__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
img_i (InputLayer)              (None, 224, 544, 3)  0                                            
__________________________________________________________________________________________________
img_f (InputLayer)              (None, 224, 544, 3)  0                                            
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 224, 544, 16) 1216        img_i[0][0]                      
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 224, 544, 16) 448         img_f[0][0]                      
__________________________________________________________________________________________________
batch_norm

Input names:
img_i_1:0
img_f_1:0
Output name:
class_1/Sigmoid:0
INFO:tensorflow:Restoring parameters from ./models/snapshot-0
INFO:tensorflow:Froze 156 variables.
Converted 156 variables to const ops.
