TensorFlow provides APIs for use in Java programs. ([link](https://www.tensorflow.org/install/install_java))

Scala runs on the JVM, so Java and Scala code can be mixed freely; the TensorFlow Java API can therefore be called directly from Scala. ([link](https://www.scala-lang.org/))

The LabelImage example demonstrates use of this API to classify images using a pre-trained Inception architecture convolutional neural network. ([link](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary))

It demonstrates:

* Graph construction: Using the OperationBuilder class to construct a graph to decode, resize and normalize a JPEG image.
* Model loading: Using Graph.importGraphDef() to load a pre-trained Inception model.
* Graph execution: Using a Session to execute the graphs and find the best label for an image.
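The final step, finding the best label, amounts to an argmax over the probability vector the graph returns. The Python sketch below shows only that step, independent of TensorFlow; the labels and scores in it are hypothetical placeholders for the model's label list and softmax output.

```python
from typing import Sequence, Tuple


def best_label(labels: Sequence[str], probabilities: Sequence[float]) -> Tuple[str, float]:
    """Return (label, probability) for the highest-scoring class.

    `probabilities[i]` is assumed to be the model's score for `labels[i]`.
    """
    if len(labels) != len(probabilities):
        raise ValueError("labels and probabilities must have the same length")
    if not labels:
        raise ValueError("need at least one label")
    # argmax: index of the largest score (the first one wins on ties)
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return labels[best], probabilities[best]


if __name__ == "__main__":
    # Hypothetical labels and softmax scores, for illustration only.
    labels = ["cat", "dog", "tabby"]
    probs = [0.10, 0.25, 0.65]
    label, p = best_label(labels, probs)
    print("Best match: %s (%.2f%% likely)" % (label, p * 100))
```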

If you see a standalone TensorFlow file representing a model, it most likely contains a serialized GraphDef (the protobuf message that describes a TensorFlow graph), saved out by the protobuf code. ([link](https://www.tensorflow.org/extend/tool_developers/#graphdef))

Keeping the graph definition and the checkpoint files (which hold the trained variable values) separate is often inconvenient when deploying to production, so TensorFlow provides the [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) script, which takes a graph definition and a set of checkpoints and freezes them together into a single file. ([link](https://www.tensorflow.org/extend/tool_developers/#freezing))

You can also look at freeze_graph_test.py for an example of how to use it. ([link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py#L33))
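Conceptually, freezing swaps every variable node in the graph for a constant node holding the value saved in the checkpoint, which is why the result is a single self-contained file (and why the notebook below asserts that no `Variable`/`VariableV2` ops remain). The sketch below illustrates only that substitution, on a hypothetical toy representation: nodes are plain dicts with `name`, `op`, and `inputs`, and the checkpoint is a dict from variable name to value. It is not freeze_graph.py's implementation, which works on GraphDef protobufs and also prunes nodes that the requested outputs don't need.

```python
import copy
from typing import Any, Dict, List

# Op types treated as variables (the same two the notebook checks for).
VARIABLE_OPS = {"Variable", "VariableV2"}


def freeze(nodes: List[Dict[str, Any]], checkpoint: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return a copy of `nodes` with each variable node replaced by a Const node.

    `checkpoint` maps variable names to their saved values; a variable with
    no saved value raises KeyError. The input list is not modified.
    """
    frozen = []
    for node in nodes:
        node = copy.deepcopy(node)
        if node["op"] in VARIABLE_OPS:
            value = checkpoint[node["name"]]  # look up the trained value
            node["op"] = "Const"
            node["value"] = value  # bake it into the graph itself
        frozen.append(node)
    return frozen


if __name__ == "__main__":
    # Hypothetical toy graph: y = sigmoid(x @ w), with w stored in a checkpoint.
    graph = [
        {"name": "input_1", "op": "Placeholder", "inputs": []},
        {"name": "dense/kernel", "op": "VariableV2", "inputs": []},
        {"name": "dense/MatMul", "op": "MatMul", "inputs": ["input_1", "dense/kernel"]},
        {"name": "dense/Sigmoid", "op": "Sigmoid", "inputs": ["dense/MatMul"]},
    ]
    checkpoint = {"dense/kernel": [[0.5], [-0.25]]}
    for node in freeze(graph, checkpoint):
        print(node["name"], node["op"])
```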

If your graph makes use of the Keras learning phase (different behavior at training time and test time), the very first thing to do before exporting your model is to hard-code the value of the learning phase (as 0, presumably, i.e. test mode) into your graph. ([link](https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html#exporting-a-model-with-tensorflow-serving))
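Dropout is the usual example of a layer whose behavior depends on the learning phase. The pure-Python sketch below (an illustration, not Keras's code) implements inverted dropout, the scheme Keras's `Dropout` layer uses: with learning phase 1 (training) each unit is zeroed with probability `rate` and the survivors are scaled by 1/(1 - rate); with learning phase 0 (test) the input passes through unchanged. Hard-coding phase 0 is what keeps the exported graph deterministic.

```python
import random
from typing import List, Optional


def dropout(x: List[float], rate: float, learning_phase: int,
            rng: Optional[random.Random] = None) -> List[float]:
    """Inverted dropout whose behavior depends on the learning phase.

    learning_phase=0 (test): identity, so the output is deterministic.
    learning_phase=1 (train): drop each unit with probability `rate` and
    scale the survivors by 1 / (1 - rate) to keep the expected value.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if learning_phase == 0:
        return list(x)
    rng = rng or random.Random()
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in x]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    print(dropout(x, rate=0.5, learning_phase=0))  # always [1.0, 2.0, 3.0, 4.0]
    # Training mode: a random mask, with surviving values doubled.
    print(dropout(x, rate=0.5, learning_phase=1, rng=random.Random(0)))
```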


```python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the graph freezing tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib


class FreezeGraphTest(test_util.TensorFlowTestCase):

    def _testFreezeGraph(self, saver_write_version):

        checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # We'll classify Kaggle's "Planet: Understanding the Amazon from Space" classes
        # with Raster Vision's ResNet50.
        with ops.Graph().as_default():
            # https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html#exporting-a-model-with-tensorflow-serving
            from keras import backend as K

            K.set_learning_phase(0)  # all new operations will be in test mode from now on

            # https://keras.io/applications/#usage-examples-for-image-classification-models
            from rastervision.common.models.resnet50 import ResNet50
            from keras.preprocessing import image
            from keras.applications.resnet50 import preprocess_input
            import numpy as np

            # Based on Raster Vision's Planet Kaggle tagging code and options.json
            classes = ['agriculture', 'artisinal_mine', 'bare_ground', 'blooming', 'blow_down',
                       'clear', 'cloudy', 'conventional_mine', 'cultivation', 'habitation',
                       'haze', 'partly_cloudy', 'primary', 'road', 'selective_logging',
                       'slash_burn', 'water']
            model = ResNet50(
                include_top=True, weights='imagenet',
                input_shape=(256, 256, 3),
                classes=len(classes),
                activation='sigmoid')
            model.load_weights("best_model.h5", by_name=True)

            img_path = 'train_1000.jpg'  # training image from Kaggle
            img = image.load_img(img_path, target_size=(256, 256))
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)  # add a batch dimension: (1, 256, 256, 3)
            x = preprocess_input(x)

            preds = model.predict(x)
            # Pair each class with its sigmoid score for the first (only) image
            # in the batch and show the five highest-scoring tags.
            print('Predicted:', sorted(zip(classes, preds[0]), key=lambda d: d[1], reverse=True)[0:5])

            sess = K.get_session()  # https://www.tensorflow.org/api_docs/python/tf/keras/backend/get_session
            # Run the same input through the raw TensorFlow session;
            # the scores should match the Keras predictions above.
            input_node = sess.graph.get_tensor_by_name("input_1:0")
            output_node = sess.graph.get_tensor_by_name("dense/Sigmoid:0")
            output = sess.run(output_node, feed_dict={input_node: x})
            print('Predicted:', sorted(zip(classes, output[0]), key=lambda d: d[1], reverse=True)[0:5])

            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

        # We save out the graph to disk, and then call the const conversion
        # routine.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_saver_def_path = ""
        input_binary = False  # write_graph() above writes a text GraphDef by default
        output_node_names = "dense/Sigmoid"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
        clear_devices = False

        # The freeze_graph() in the TensorFlow version used here has no
        # input_meta_graph parameter ("unexpected keyword argument
        # 'input_meta_graph'"), so the meta graph saved alongside the
        # checkpoint (checkpoint_path + ".meta") is not passed.
        freeze_graph.freeze_graph(
            input_graph_path, input_saver_def_path, input_binary, checkpoint_path,
            output_node_names, restore_op_name, filename_tensor_name,
            output_graph_path, clear_devices,
            "",   # initializer_nodes: none
            "")   # variable-name filter: none

        # Now make sure the variables are now constants, and that the frozen
        # graph still produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            _ = importer.import_graph_def(output_graph_def, name="")

            for node in output_graph_def.node:
                self.assertNotEqual("VariableV2", node.op)
                self.assertNotEqual("Variable", node.op)

            with session.Session() as sess:
                input_node = sess.graph.get_tensor_by_name("input_1:0")
                output_node = sess.graph.get_tensor_by_name("dense/Sigmoid:0")
                output = sess.run(output_node, feed_dict={input_node: x})
                print('Predicted:', sorted(zip(classes, output[0]), key=lambda d: d[1], reverse=True)[0:5])
```

```python
FreezeGraphTest()._testFreezeGraph(saver_pb2.SaverDef.V2)
```

```
Using TensorFlow backend.
Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
INFO:tensorflow:Froze 320 variables.
Converted 320 variables to const ops.
1253 ops in the final graph.
Predicted: [('habitation', 1.0), ('partly_cloudy', 1.0), ('primary', 1.0), ('road', 0.040755637), ('cultivation', 4.5596539e-29)]
```
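The notebook compares the Keras, session, and frozen-graph predictions by eye. A small plain-Python sketch of automating that check follows: `top_k` reproduces the notebook's `sorted(zip(...))[0:5]` ranking, and `same_predictions` compares two score vectors element-wise within a tolerance. The helper names, the tolerance defaults, and the example scores are assumptions for illustration; inside the notebook, `numpy.testing.assert_allclose(output, preds, rtol=1e-5)` would perform the same comparison on the arrays directly.

```python
import math
from typing import List, Sequence, Tuple


def top_k(classes: Sequence[str], scores: Sequence[float], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k (class, score) pairs with the highest scores, best first."""
    return sorted(zip(classes, scores), key=lambda d: d[1], reverse=True)[:k]


def same_predictions(a: Sequence[float], b: Sequence[float],
                     rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """True if both score vectors have equal length and match element-wise."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b))


if __name__ == "__main__":
    classes = ["habitation", "road", "water"]
    keras_scores = [1.0, 0.040755637, 0.001]   # hypothetical Keras scores
    frozen_scores = [1.0, 0.040755640, 0.001]  # hypothetical frozen-graph scores
    print(top_k(classes, keras_scores, k=2))
    print(same_predictions(keras_scores, frozen_scores))
```

Comparing with a tolerance rather than exact equality leaves room for small floating-point differences if the frozen graph is later run on different hardware or a different TensorFlow build.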
