# IBM Code Model Asset Exchange Image Caption Generator

## Setup

1. In a terminal window, run the following commands to set up the artifacts in your local environment:
```
git clone https://github.com/IBM/MAX-Image-Caption-Generator
cd MAX-Image-Caption-Generator
curl -O http://max-assets.s3-api.us-geo.objectstorage.softlayer.net/tf/im2txt/im2txt_ckpt.tar.gz
tar -zxvf im2txt_ckpt.tar.gz -C assets/
```    

2. Copy this notebook into the `MAX-Image-Caption-Generator` directory.
3. Start Jupyter by running `jupyter notebook .` and open this notebook.

In [None]:
# This notebook has been tested with Python version 3.6
!python --version

In [None]:
# Install the latest package versions
!pip install -Iv dask
!pip install -Iv tensorflow
!pip install -Iv tensorflowjs
!pip install -Iv pandas
# Restart the kernel after installation completes.

In [None]:
# This notebook has been tested with dask 0.18, tensorflow 1.8, tensorflowjs 0.4.1, and pandas 0.23.1
!pip show tensorflow tensorflowjs dask pandas

In [None]:
import datetime
import pprint
import os
import pathlib
import time
from IPython.display import Image

In [None]:
import tensorflow as tf

print('TF versions:', tf.GIT_VERSION, tf.VERSION)

In [None]:
pp = pprint.PrettyPrinter(indent=4)

## Load model from checkpoint and run a prediction

In [None]:
from core.backend import ModelWrapper

# instantiate model
m = ModelWrapper()


In [None]:
# # uncomment to list the graph nodes and their ops

#[n.name + '=>' +  n.op for n in m.sess.graph.as_graph_def().node]


In [None]:
# # uncomment to print all tensors in the checkpoint file

#from tensorflow.python.tools import inspect_checkpoint as chkp
#chkp.print_tensors_in_checkpoint_file(checkpoint_dir_and_prefix, tensor_name='', all_tensors=True, all_tensor_names=False)


Display the test image and run prediction

In [None]:
# path to an image file to run prediction
test_image_path = './assets/plane.jpg'
# display image
Image(url=test_image_path)

In [None]:
# run prediction on image
with open(test_image_path, 'rb') as image:

    pp.pprint(m.predict(image.read()))

<br>
<hr>

## Create a frozen model from the downloaded checkpoint

- **[TensorFlow Model Files](https://www.tensorflow.org/extend/tool_developers/)**: developer's guide

  - **checkpoints**: model format dependent on the code that created the model
  ```
    /
    checkpoint
    model.ckpt-?????.data-?????-of-?????
    model.ckpt-?????.index
    model.ckpt-?????.meta
  ```
  - **SavedModel**: model format independent of the code that created the model
  ```
    assets/
    assets.extra/
    variables/
        variables.data-?????-of-?????
        variables.index
    saved_model.pb|saved_model.pbtxt
  ```
  - **frozen model**: single file graph def (variables converted into inline constants)
  ```
    model.pb
  ```

Customize the artifact output directory. Two sub-directories will be created in this location if they don't exist yet:
 * `frozen_graph_assets`
 * `web_assets`

In [None]:
base_output_dir = "/tmp"

In [None]:
# Derive the two artifact directories from the user-configurable base location.
frozen_graph_dir = "{}/{}".format(base_output_dir, "frozen_graph_assets")
web_asset_dir = "{}/{}".format(base_output_dir, "web_assets")

# Create the output directories if they do not exist yet. `parents=True` also
# creates a custom `base_output_dir` that is missing, and the loop variable is
# named `output_dir` so the `dir` builtin is not shadowed in the notebook.
for output_dir in [frozen_graph_dir, web_asset_dir]:
    try:
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
    except FileExistsError:
        # With exist_ok=True this is only raised when the path exists but is
        # not a directory.
        print("Output location {} already exists and is not a directory.".format(output_dir))

# Embed a minute-resolution timestamp in the frozen graph file names.
now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
frozen_graph_filename = "frozen_graph_{}.pb".format(now)
frozen_graph_stripped_filename = "frozen_graph_stripped_{}.pb".format(now)
frozen_graph_path = "{}/{}".format(frozen_graph_dir, frozen_graph_filename)
frozen_graph_stripped_path = "{}/{}".format(frozen_graph_dir, frozen_graph_stripped_filename)

print('The frozen graph files for this model will be stored in `{}`'.format(frozen_graph_dir))
print('The Tensorflow.js files for this model will be stored in `{}`'.format(web_asset_dir))

## Create frozen graph from checkpoint

In [None]:
import tensorflow as tf
from core.backend import ModelWrapper

# Instantiate the model wrapper and take the TensorFlow session it exposes.
m = ModelWrapper()
sess = m.sess

# Graph definition of the session's graph; this is the input to the freezing step.
input_graph_def = sess.graph.as_graph_def()


# Comma-separated names of the output nodes to keep in the frozen graph
# (most of the time you will only be choosing the prediction node).
output_node_names = 'softmax,lstm/initial_state,lstm/state'


# Freeze the graph: convert_variables_to_constants receives the session, the
# graph definition and the list of output node names to preserve.
output_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, # the session taken from the model wrapper
    input_graph_def, # graph definition the nodes are retrieved from
    output_node_names.split(",")  # list of output node names
)

# Serialize the frozen graph and write it to `frozen_graph_path`.

with tf.gfile.GFile(frozen_graph_path, "wb") as f:
    f.write(output_graph_def.SerializeToString())
    print('Saved frozen graph: ' + frozen_graph_path)

# Close the session now that the frozen graph has been written.
sess.close()

## Load frozen graph and run inference

In [None]:
# Read the frozen graph file at `frozen_graph_path` and deserialize its
# contents into a new GraphDef object.
with tf.gfile.GFile(frozen_graph_path, "rb") as f:
    restored_graph_def = tf.GraphDef()
    restored_graph_def.ParseFromString(f.read())

In [None]:
# import the graph_def using tf.import_graph_def function
from core import inference_wrapper

# Create a new graph, make it the default inside the block, and import the
# restored (frozen) graph definition into it.
with tf.Graph().as_default() as graph:
    # NOTE(review): the wrapper is instantiated inside the graph context, but
    # nothing visible here builds a graph through it; presumably it only
    # supplies the helpers the caption generator calls -- confirm against
    # core/inference_wrapper.py.
    model = inference_wrapper.InferenceWrapper()
    tf.import_graph_def(
        restored_graph_def,
        input_map=None,
        return_elements=None,
        # NOTE(review): an empty name should keep the original node names
        # (no import prefix) -- confirm against the tf.import_graph_def docs.
        name=""
    )

# Session bound to the graph holding the imported frozen model; it is the
# session later passed to predict().
sess = tf.Session(graph=graph)

In [None]:
import math
import logging
from core.inference_utils import vocabulary
from core.inference_utils import caption_generator

logger = logging.getLogger()

# path to the word counts file (in repo)
VOCAB_FILE = './assets/word_counts.txt'

# run prediction
def predict(sess, model, image_data):
    """Generate captions for an image and return them as result tuples.

    Args:
        sess: Session handed to the caption generator's beam search.
        model: Model object passed to ``caption_generator.CaptionGenerator``.
        image_data: Image data forwarded unchanged to ``beam_search``.

    Returns:
        A list with one ``(index, sentence, probability)`` tuple per caption
        returned by the beam search, in the order the search returned them.
        ``probability`` is ``math.exp(caption.logprob)``.
    """
    # The vocabulary is loaded from the word counts file on every call.
    vocab = vocabulary.Vocabulary(VOCAB_FILE)

    # No beam search parameters are passed, so the defaults apply; see
    # caption_generator.py for a description of the available parameters.
    generator = caption_generator.CaptionGenerator(model, vocab)

    def to_sentence(word_ids):
        # Skip the first and last ids (the begin and end words).
        return " ".join(vocab.id_to_word(word_id) for word_id in word_ids[1:-1])

    return [
        (rank, to_sentence(candidate.sentence), math.exp(candidate.logprob))
        for rank, candidate in enumerate(generator.beam_search(sess, image_data))
    ]


# helper function to get raw image data from given img path
def get_image_data(path_to_img):
    """Return the raw bytes of the image file at ``path_to_img``.

    Args:
        path_to_img: Path of the file to read; it is opened in binary mode.

    Returns:
        The complete file contents as ``bytes``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the handle even when read() raises.
    with open(path_to_img, 'rb') as image_handle:
        return image_handle.read()


Display validation image and run prediction using the frozen graph. You can use any image.

In [None]:
validation_image_path = './assets/soccer.jpg'
# display image
from IPython.display import Image
Image(url=validation_image_path)

In [None]:
with open(validation_image_path, 'rb') as image:
    # run prediction on image
    res = predict(sess, model, image.read())
    pp.pprint(res)

<br>
<hr>

# Converting to a web-friendly format

[https://github.com/tensorflow/tfjs-converter](https://github.com/tensorflow/tfjs-converter)


```
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='softmax,lstm/initial_state,lstm/state' \
    --saved_model_tags=serve \
    /path/to/frozen/model.pb \
    /path/to/web_asset_output_dir
```


## Load frozen graph 

In [None]:
# Report which frozen graph file is used and its size; `>> 20` converts the
# byte count to whole MiB (rounded down).
print("Frozen graph file: {}".format(frozen_graph_path))
print(" File size: {} MiB".format(os.path.getsize(frozen_graph_path) >> 20))

# Read the frozen graph file and deserialize its contents into a new
# GraphDef object.
with tf.gfile.GFile(frozen_graph_path, "rb") as f:
    restored_graph_def = tf.GraphDef()
    restored_graph_def.ParseFromString(f.read())

## Strip unused nodes from graph

In [None]:
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile

# TODO figure out the appropriate input and output nodes required
# input_node_names = ['Mul']
# output_node_names = ['softmax', 'lstm/initial_state', 'lstm/state']
input_node_names = []
output_node_names = ['softmax']

# Strip the restored graph down using the input/output node names chosen
# above; float32 is passed as the placeholder type.
gdef = strip_unused_lib.strip_unused(
    input_graph_def=restored_graph_def,
    input_node_names=input_node_names,
    output_node_names=output_node_names,
    placeholder_type_enum=dtypes.float32.as_datatype_enum)

# Serialize the stripped graph and write it to `frozen_graph_stripped_path`.
with gfile.GFile(frozen_graph_stripped_path, "wb") as f:
    f.write(gdef.SerializeToString())

# Report the size only after the `with` block has closed the file: while the
# file is still open for writing, unflushed data can make the size too small.
print("Stripped frozen graph file: {}".format(frozen_graph_stripped_path))
print(" File size: {} MiB".format(os.path.getsize(frozen_graph_stripped_path) >> 20))

## Convert frozen graph to TensorFlow.js

In [None]:
# set appropriate frozen model path and desired output path for web format

!tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='softmax' \
    --saved_model_tags=serve \
    {frozen_graph_stripped_path} \
    {web_asset_dir}


In [None]:
# List the generated web assets sorted by name, one line per file showing the
# name (left-justified to 30 columns), the modification time and the size in
# bytes (right-aligned to 20 columns).
print("Web asset directory {}:".format(web_asset_dir))
web_assets = sorted(os.listdir(web_asset_dir))
for file in web_assets:
    file_stat = os.stat("{}/{}".format(web_asset_dir, file))
    padded_name = file.ljust(30)
    modified_at = time.ctime(file_stat.st_mtime)
    print(" {} {} {:>20}".format(padded_name, modified_at, file_stat.st_size))

Use the artifacts listed above in your TensorFlow.js application.