# [A Tool Developer's Guide to TensorFlow Model Files](https://www.tensorflow.org/extend/tool_developers/)

Most users shouldn't need to care about the internal details of how TensorFlow stores data on disk, but you might if you're a tool developer. For example, you may want to analyze models, or convert back and forth between TensorFlow and other formats. This guide explains how to work with the main files that hold model data, to make it easier to develop those kinds of tools.

### [Freezing](https://www.tensorflow.org/extend/tool_developers/#freezing)
One confusing part about TensorFlow's model files is that the weights usually aren't stored inside the `GraphDef` file during training. Instead, they're held in separate checkpoint files, and there are `Variable` ops in the graph that load the latest values when they're initialized. Separate files are often inconvenient when you're deploying to production, so there's the [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) script, which takes a graph definition and a set of checkpoints and freezes them together into a single file.

What this does is load the `GraphDef`, pull in the values for all the variables from the latest checkpoint file, and then replace each `Variable` op with a `Const` that has the numerical data for the weights stored in its attributes. It then strips away all the extraneous nodes that aren't used for forward inference, and saves out the resulting `GraphDef` into an output file.
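The two steps described above (swap each variable for a constant, then prune what inference doesn't need) can be sketched in plain Python. This is a toy illustration only: the dict-based nodes, the `freeze` function, and the sample graph are hypothetical stand-ins for `GraphDef` nodes and checkpoint data, not TensorFlow's actual implementation.

```python
# Toy model of graph freezing. Plain dicts stand in for GraphDef nodes,
# and a dict of values stands in for the checkpoint (both are assumptions).

def freeze(nodes, checkpoint, output_names):
    """Replace Variable nodes with Const nodes, then keep only what the outputs need."""
    by_name = {}
    for node in nodes:
        if node["op"] == "Variable":
            # Bake the checkpoint value into the node itself.
            node = {"name": node["name"], "op": "Const",
                    "inputs": [], "value": checkpoint[node["name"]]}
        by_name[node["name"]] = node

    # Walk backwards from the requested outputs to find the nodes inference needs.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(by_name[name]["inputs"])

    # Preserve the original node order in the result.
    return [by_name[n["name"]] for n in nodes if n["name"] in needed]


graph = [
    {"name": "input", "op": "Placeholder", "inputs": []},
    {"name": "weights", "op": "Variable", "inputs": []},
    {"name": "matmul", "op": "MatMul", "inputs": ["input", "weights"]},
    {"name": "softmax", "op": "Softmax", "inputs": ["matmul"]},
    {"name": "save/restore_all", "op": "NoOp", "inputs": ["weights"]},
]
checkpoint = {"weights": [[0.1, 0.2], [0.3, 0.4]]}

frozen = freeze(graph, checkpoint, ["softmax"])
for node in frozen:
    print(node["name"], node["op"])
```

In this toy graph the `save/restore_all` node is dropped because nothing on the path to `softmax` depends on it, and `weights` comes back as a `Const` carrying its value.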

# [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py#L15-L33)

Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:

    bazel build tensorflow/python/tools:freeze_graph && \
    bazel-bin/tensorflow/python/tools/freeze_graph \
        --input_graph=some_graph_def.pb \
        --input_checkpoint=model.ckpt-8361242 \
        --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.
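If you drive the tool from a script, it can help to assemble the command as an argument list. The sketch below only uses the four flags quoted in the docstring; the helper name `freeze_graph_command` and the default binary path are assumptions for illustration, and it assumes multiple output nodes are passed as a comma-separated list.

```python
import shlex


def freeze_graph_command(input_graph, input_checkpoint, output_graph,
                         output_node_names,
                         binary="bazel-bin/tensorflow/python/tools/freeze_graph"):
    """Build an argument list for freeze_graph (hypothetical helper)."""
    return [
        binary,
        "--input_graph=" + input_graph,
        "--input_checkpoint=" + input_checkpoint,
        "--output_graph=" + output_graph,
        # Assumption: several output nodes are joined with commas.
        "--output_node_names=" + ",".join(output_node_names),
    ]


cmd = freeze_graph_command(
    "some_graph_def.pb", "model.ckpt-8361242", "/tmp/frozen_graph.pb", ["softmax"])

# Print a shell-safe version of the command for inspection.
print(" ".join(shlex.quote(part) for part in cmd))
```

Once the binary has been built, a list like `cmd` could be handed to `subprocess.run(cmd, check=True)`.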

# [freeze_graph_test.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py#L15-L107) with [Keras ResNet50](https://keras.io/applications/#usage-examples-for-image-classification-models)

In [1]:
"""Tests the graph freezing tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib


class FreezeGraphTest(test_util.TensorFlowTestCase):
    
    def _testFreezeGraph(self, saver_write_version):
        
        checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
        checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(),
                                                  "saved_checkpoint.meta")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        with ops.Graph().as_default():
            # All new operations will be in test mode from now on
            # See: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html#exporting-a-model-with-tensorflow-serving
            from keras import backend as K
            K.set_learning_phase(0)
            
            from rastervision.common.models.resnet50 import ResNet50
            from keras.preprocessing import image
            from keras.applications.resnet50 import preprocess_input
            import numpy as np

            classes = ['agriculture', 'artisinal_mine', 'bare_ground', 'blooming', 'blow_down', 'clear', 'cloudy', 'conventional_mine', 'cultivation', 'habitation', 'haze', 'partly_cloudy', 'primary', 'road', 'selective_logging', 'slash_burn', 'water']
            model = ResNet50(
                include_top=True, weights='imagenet',
                input_shape=(256, 256, 3),
                classes=len(classes),
                activation='sigmoid')
            model.load_weights("best_model.h5", by_name=True)

            img_path = 'train_1000.jpg'
            img = image.load_img(img_path, target_size=(256, 256))
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)

            preds = model.predict(x)
            # Print the five highest-scoring (class, probability) pairs for the
            # first (and only) sample in the batch.
            print('Predicted:', sorted(zip(classes, preds[0]), key=lambda d: d[1], reverse=True)[0:5])
            
            sess = K.get_session()
            input_node = sess.graph.get_tensor_by_name("input_1:0")
            output_node = sess.graph.get_tensor_by_name("dense/Sigmoid:0")
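            output = sess.run(output_node, feed_dict={input_node: x})
            # Print the scores fetched directly from the session (not the earlier
            # Keras `preds`), so the two code paths can be compared.
            print('Predicted:', sorted(zip(classes, output[0]), key=lambda d: d[1], reverse=True)[0:5])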
            
            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # We save out the graph to disk, and then call the const conversion
        # routine.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_saver_def_path = ""
        input_binary = False
        output_node_names = "dense/Sigmoid"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
        clear_devices = False
        input_meta_graph = checkpoint_meta_graph_file
        
        # `input_meta_graph` is deliberately not passed: with the TensorFlow
        # version used here, freeze_graph() raised "got an unexpected keyword
        # argument 'input_meta_graph'".
        freeze_graph.freeze_graph(
            input_graph_path, input_saver_def_path, input_binary, checkpoint_path,
            output_node_names, restore_op_name, filename_tensor_name,
            output_graph_path, clear_devices, "", "")

        # Now we make sure that every variable has become a constant, and that
        # the graph still produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                _ = importer.import_graph_def(output_graph_def, name="")

            for node in output_graph_def.node:
                self.assertNotEqual("VariableV2", node.op)
                self.assertNotEqual("Variable", node.op)
                
            with session.Session() as sess:
                input_node = sess.graph.get_tensor_by_name("input_1:0")
                output_node = sess.graph.get_tensor_by_name("dense/Sigmoid:0")
                output = sess.run(output_node, feed_dict={input_node: x})
                # Print the frozen graph's own output rather than the earlier `preds`.
                print('Predicted:', sorted(zip(classes, output[0]), key=lambda d: d[1], reverse=True)[0:5])

In [2]:
FreezeGraphTest()._testFreezeGraph(saver_pb2.SaverDef.V2)

Using TensorFlow backend.


Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
INFO:tensorflow:Froze 320 variables.
Converted 320 variables to const ops.
1253 ops in the final graph.
Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
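
The three `Predicted:` lines above are compared by eye. A small helper could make that check explicit; the sketch below is one possible approach in plain Python. The names `top_k` and `predictions_match` and the tolerances are assumptions, the `keras_scores` values are copied from the output above, and the `frozen_scores` values are made up purely to exercise the comparison.

```python
import math


def top_k(classes, scores, k=5):
    """Return the k highest-scoring (class, score) pairs, best first."""
    return sorted(zip(classes, scores), key=lambda pair: pair[1], reverse=True)[:k]


def predictions_match(before, after, rel_tol=1e-5, abs_tol=1e-8):
    """True if two score lists have equal length and agree within tolerance."""
    return len(before) == len(after) and all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(before, after))


classes = ["habitation", "road", "cultivation"]
keras_scores = [1.0, 0.040755637, 4.5596539e-29]   # taken from the output above
frozen_scores = [1.0, 0.040755640, 0.0]            # hypothetical, for illustration

print(top_k(classes, keras_scores, k=2))
print(predictions_match(keras_scores, frozen_scores))
```

The absolute tolerance matters here because scores such as `4.5596539e-29` are effectively zero, and a purely relative comparison against `0.0` would fail.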
