# Tensorflow Notebook - Classification

In [1]:
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

In [2]:
model_dir = '/tmp/imagenet'
image_file = ''
num_top_predictions = 5

DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

In [3]:
def maybe_download_and_extract():
    """
    データをダウンロードして、解凍する
    """
    dest_directory = model_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

In [4]:
class NodeLookup(object):
    """
    分類番号をラベルに変換
    """
    def __init__(self, label_lookup_path=None, uid_lookup_path=None):
        if not label_lookup_path:
            label_lookup_path = os.path.join(model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
            uid_lookup_path = os.path.join(model_dir, 'imagenet_synset_to_human_label_map.txt')
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    def load(self, label_lookup_path, uid_lookup_path):
        """
        ラベルを読み込む
        
        Args:
            label_lookup_path: string UID to integer node ID
            uid_lookup_path: string UID to human-readable string
            
        Returns:
           
        """
        if not tf.gfile.Exists(uid_lookup_path):
            tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
            tf.logging.fatal('File does not exist %s', label_lookup_path)

        # UIDからラベルへのマッピング
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        uid_to_human = {}
        p = re.compile(r'[n\d]*[\S,]*')
        for line in proto_as_ascii_lines:
            parsed_items = p.findall(line)
            uid = parsed_items[0]
            human_string = parsed_items[2]
            uid_to_human[uid] = human_string

        # UIDからnode IDへのマッピング
        node_id_to_uid = {}
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        for line in proto_as_ascii:
            if line.startswith('  target_class:'):
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
                target_class_string = line.split(': ')[1]
                node_id_to_uid[target_class] = target_class_string[1:-2]

        # node IDからラベルへのマッピング
        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
            if val not in uid_to_human:
                tf.logging.fatal('Failed to locate: %s', val)
            name = uid_to_human[val]
            node_id_to_name[key] = name

        return node_id_to_name

    def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]

In [5]:
def create_graph():
    """
    GraphDefを読み込み、saverを返却
    """
    with tf.gfile.FastGFile(os.path.join(
            model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

def run_inference_on_image(image):
    """
    画像から推論を実行
    Args:
        image: 推論させる画像
    
    Returns:
        Nothings
    """
    if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf.gfile.FastGFile(image, 'rb').read()

    # 保存しているGraphDefからgraphを作成
    create_graph()

    with tf.Session() as sess:
        # Some useful tensors:
        #  'softmax:0': A tensor containing the normalized prediction across 1000 labels
        #  'pool_3:0': A tensor containing the next-to-last layer containing 2048 float description of the image.
        # 'DecodeJpeg/contents:0': A tensor contaning a string providing JPEG providing JPEG encoding of the images

        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        # node ID --> ラベル
        node_lookup = NodeLookup()

        top_k = predictions.argsort()[-num_top_predictions:][::-1]
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

In [10]:
maybe_download_and_extract()

>> Downloading inception-2015-12-05.tgz 100.0%
Succesfully downloaded inception-2015-12-05.tgz 88931400 bytes.


In [11]:
image = (image_file if image_file else
                os.path.join(model_dir, 'cropped_panda.jpg'))

In [12]:
run_inference_on_image(image)

giant (score = 0.89233)
indri, (score = 0.00859)
lesser (score = 0.00264)
custard (score = 0.00141)
earthstar (score = 0.00107)
