## Getting Checkpoint file

This sample code will show how to freeze tensorflow snapshot using tensorflow model_freezing tool. The snapshot files are slim based files.

In [1]:
import urllib3
import tarfile
import os

In [2]:
path_ckpt='snapshots'
path_frozen_models='frozen_models'

if (os.path.exists(path_ckpt) is False):
    os.mkdir(path_ckpt)
if (os.path.exists(path_frozen_models) is False):
    os.mkdir(path_frozen_models)

Put tensorflow snapshot files as you want.

In [3]:
link='http://download.tensorflow.org/models'
ls_model_file=[
        'inception_v3_2016_08_28.tar.gz',
        'resnet_v1_50_2016_08_28.tar.gz',
        'resnet_v2_50_2017_04_14.tar.gz',
        'mobilenet_v1_0.50_160_2017_06_14.tar.gz']

### Download tensorflow snapshot files and untar them

In [4]:
import wget
def download_ckpt(link, ls_model):
    for model in ls_model:
        link_ = os.path.join(link, model)
        print('Downloading', model, '...  ', end="")
        if (os.path.exists(os.path.join(path_ckpt, model)) is False):
            wget.download(link_, out=os.path.join(path_ckpt, model))
            print('done')
        else:
            print('used already exist!!')
        
download_ckpt(link, ls_model_file)

Downloading inception_v3_2016_08_28.tar.gz ...  used already exist!!
Downloading resnet_v1_50_2016_08_28.tar.gz ...  used already exist!!
Downloading resnet_v2_50_2017_04_14.tar.gz ...  used already exist!!
Downloading mobilenet_v1_0.50_160_2017_06_14.tar.gz ...  used already exist!!


In [5]:
import tarfile,sys

def untar_ckpt(fname):
    with tarfile.open(fname) as tar:
        model_name = [file for file in tar.getnames() if "ckpt" in file][0].split('.')[0]
        print(model_name, end='')
        if (os.path.exists(os.path.join(path_ckpt, model_name)) is False):
            tar.extractall(os.path.join(path_ckpt, model_name))
            print('.. done')
        else:
            print('.. already exist!!')
    return model_name
            
ls_model = []
for model_file in ls_model_file:
    model = untar_ckpt(os.path.join('./', path_ckpt, model_file))
    ls_model.append(model)

inception_v3.. already exist!!
resnet_v1_50.. already exist!!
resnet_v2_50.. already exist!!
mobilenet_v1_0.. already exist!!


### Import tensorflow  & dependencies

In [6]:
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.saved_model import tag_constants

  return f(*args, **kwds)


Instructions for updating:
Use the retry module or similar alternatives.


Since we are using slim model, we need to get slim graph from tensorflow

In [7]:
if (os.path.exists('models') is False):
    os.system('git clone -b 2d7a0d6abba764b768d645947014492ade492385 https://github.com/tensorflow/models')
else:
    print('Confirmed that tensorflow models directory existance')

Confirmed that tensorflow models directory existance


In [8]:
from models.research.slim.nets import nets_factory

### Select model from the list

In [9]:
model_name = ls_model[2]
print("Selected model is", model_name)

Selected model is resnet_v2_50


In [10]:
graphdef_file = os.path.join(path_ckpt, model_name, model_name + '_graph.pb')
checkpoint_path = os.path.join(path_ckpt, model_name)
frozenmodel_path = os.path.join(path_ckpt, model_name, model_name + '_frozen.pb')

And we export graphdef file to get graph's output layer name

In [11]:
%%bash -s "$model_name" "$graphdef_file"
echo "Model name: $1"
echo "Graph path: $2"
python models/research/slim/export_inference_graph.py \
    --model_name=$1 \
    --output_file=$2

Model name: resnet_v2_50
Graph path: snapshots/resnet_v2_50/resnet_v2_50_graph.pb


  return f(*args, **kwds)
Instructions for updating:
Use the retry module or similar alternatives.
INFO:tensorflow:Scale of 0 disables regularizer.
Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [12]:
graphdef_file = os.path.join(path_ckpt, model_name, model_name + '_graph.pbtxt')
with tf.Graph().as_default() as graph:
    network_fn = nets_factory.get_network_fn(
        model_name,
        num_classes=1001,
        is_training=False
    )
    image_size = network_fn.default_image_size
    inputs = tf.random_uniform((8, image_size, image_size, 3))
    logits, end_points = network_fn(inputs)
    out_layer = list(end_points.items())[-1][1].name.split(':')[0]
    print('output_layer:', out_layer)
    
    # remove nodes not needed for inference from graph def
    inference_graph = tf.graph_util.remove_training_nodes(graph.as_graph_def())
    
    # write the graph definition to a file
    graph_io.write_graph(inference_graph, '.', graphdef_file)
    

INFO:tensorflow:Scale of 0 disables regularizer.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
output_layer: resnet_v2_50/predictions/Reshape_1


Result files are;

In [13]:
print("Model Name:\t\t", model_name)
print("Output Layer:\t\t", out_layer)
print("Model's Graphdef file:\t", graphdef_file)
print("Frozen model file's path:", frozenmodel_path)
print("Source ckpt file path:\t", os.path.join(path_ckpt, model_name, model_name + '.ckpt'))

Model Name:		 resnet_v2_50
Output Layer:		 resnet_v2_50/predictions/Reshape_1
Model's Graphdef file:	 snapshots/resnet_v2_50/resnet_v2_50_graph.pbtxt
Frozen model file's path: snapshots/resnet_v2_50/resnet_v2_50_frozen.pb
Source ckpt file path:	 snapshots/resnet_v2_50/resnet_v2_50.ckpt


## Model Freesing
Finally, we uses tensorflow's freeze_graph tool. It is possible that some checkpoints files cannot be freezed.

In [14]:
freeze_graph.freeze_graph(
    input_graph=graphdef_file,
    input_checkpoint=os.path.join(path_ckpt, model_name, model_name + '.ckpt'),
    input_binary=False,
    output_graph=frozenmodel_path,
    output_node_names=out_layer,
    input_saver="",
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    clear_devices=True,
    initializer_nodes=""
)

INFO:tensorflow:Restoring parameters from snapshots/resnet_v2_50/resnet_v2_50.ckpt
INFO:tensorflow:Froze 272 variables.
Converted 272 variables to const ops.
