# In [4]:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""
#绝对引用
from __future__ import absolute_import 
#精确除法
from __future__ import division
#在python2的环境下，可以使用python3的print函数，确保兼容
from __future__ import print_function

import argparse
#获取文件的属性
import os.path
import re
import sys
import tarfile

import numpy as np
from matplotlib import pyplot as plt
from six.moves import urllib
# import tensorflow as tf
import tensorflow.compat.v1 as tf
# Parsed command-line flags. Stays None until the argparse code in the
# `__main__` section assigns it, which happens before `main` is run.
FLAGS = None

# Location of the Inception model archive; maybe_download_and_extract() fetches
# it when the archive is not already present in the model directory.
# NOTE(review): this is a plain-HTTP URL — consider switching to an HTTPS
# mirror; confirm one is available before changing it.
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long


# Converts class IDs into human-readable labels.
class NodeLookup(object):
  """Converts integer node ID's to human readable labels.

  The table is stored in `self.node_lookup`, a dict keyed by integer node ID.
  """

  # One line of the synset file: '<n########>\t<english label>'.
  _SYNSET_LINE = re.compile(r'(n\d*)\t(.*)')
  # One line of the Chinese label file: '<integer node id>\t<chinese label>'.
  _CHINESE_LINE = re.compile(r'(\d+)\t(.*)')

  def __init__(self,
               uid_chinese_lookup_path,
               model_dir,
               label_lookup_path=None,
               uid_lookup_path=None):
    """Builds the node ID --> label table.

    Args:
      uid_chinese_lookup_path: Path of the tab-separated
        '<node id>\\t<label>' file. When it is empty/None the English
        ImageNet files are used instead.
      model_dir: Directory used to build the default English file paths.
      label_lookup_path: Optional path of the label-map pbtxt file; defaults
        to 'imagenet_2012_challenge_label_map_proto.pbtxt' in `model_dir`.
      uid_lookup_path: Optional path of the synset-to-label file; defaults to
        'imagenet_synset_to_human_label_map.txt' in `model_dir`.
    """
    if uid_chinese_lookup_path:
      self.node_lookup = self.load_chinese_map(uid_chinese_lookup_path)
      return

    # Fallback: no Chinese map was given, use the English ImageNet files.
    if not label_lookup_path:
      label_lookup_path = os.path.join(
          model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
    if not uid_lookup_path:
      uid_lookup_path = os.path.join(
          model_dir, 'imagenet_synset_to_human_label_map.txt')
    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

  @staticmethod
  def _parse_tab_lines(lines, pattern):
    """Returns the regex groups of every non-blank line in `lines`.

    Args:
      lines: Iterable of text lines (trailing newlines allowed).
      pattern: Compiled regex with two groups; the first match in the line
        is used.

    Returns:
      List of `(key, value)` string tuples in file order.

    Raises:
      ValueError: If a non-blank line does not match `pattern`.
    """
    pairs = []
    for line_number, line in enumerate(lines, 1):
      text = line.rstrip('\r\n')
      if not text.strip():
        # Blank lines (typically a trailing one) carry no mapping.
        continue
      found = pattern.search(text)
      if found is None:
        raise ValueError(
            'Malformed label line %d: %r' % (line_number, line))
      pairs.append(found.groups())
    return pairs

  @staticmethod
  def _parse_synset_lines(lines):
    """Parses '<n########>\\t<label>' lines into a {uid: label} dict."""
    return dict(NodeLookup._parse_tab_lines(lines, NodeLookup._SYNSET_LINE))

  @staticmethod
  def _parse_chinese_lines(lines):
    """Parses '<node id>\\t<label>' lines into an {int node id: label} dict."""
    pairs = NodeLookup._parse_tab_lines(lines, NodeLookup._CHINESE_LINE)
    return {int(node_id): label for node_id, label in pairs}

  @staticmethod
  def _parse_label_map_lines(lines):
    """Parses the label-map pbtxt lines into an {int node id: uid} dict.

    Only lines starting with '  target_class:' and
    '  target_class_string:' are interpreted; each string entry is paired
    with the most recently seen integer class.

    Raises:
      ValueError: If a 'target_class_string' line appears before any
        'target_class' line.
    """
    node_id_to_uid = {}
    target_class = None
    for line in lines:
      if line.startswith('  target_class:'):
        # Integer class number.
        target_class = int(line.split(': ')[1])
      if line.startswith('  target_class_string:'):
        if target_class is None:
          raise ValueError(
              'target_class_string without target_class: %r' % line)
        # Value looks like '"n01440764"\n'; drop whitespace and the quotes
        # without assuming the line ends with a newline.
        node_id_to_uid[target_class] = (
            line.split(': ')[1].strip().strip('"'))
    return node_id_to_uid

  def load(self, label_lookup_path, uid_lookup_path):
    """Loads the English node ID --> label table.

    Args:
      label_lookup_path: Path of the label-map pbtxt file (integer class to
        synset UID string).
      uid_lookup_path: Path of the synset UID to human-readable label file.

    Returns:
      Dict from integer node ID to human-readable label.
    """
    if not tf.gfile.Exists(uid_lookup_path):
      tf.logging.fatal('File does not exist %s', uid_lookup_path)
    if not tf.gfile.Exists(label_lookup_path):
      tf.logging.fatal('File does not exist %s', label_lookup_path)

    # Synset UID string --> human-readable label.
    with tf.gfile.GFile(uid_lookup_path) as f:
      uid_to_human = self._parse_synset_lines(f.readlines())

    # Integer node ID --> synset UID string.
    with tf.gfile.GFile(label_lookup_path) as f:
      node_id_to_uid = self._parse_label_map_lines(f.readlines())

    # Join the two tables: integer node ID --> human-readable label.
    node_id_to_name = {}
    for node_id, uid in node_id_to_uid.items():
      if uid not in uid_to_human:
        # NOTE(review): if tf.logging.fatal does not abort the process, the
        # lookup below raises KeyError for this uid.
        tf.logging.fatal('Failed to locate: %s', uid)
      node_id_to_name[node_id] = uid_to_human[uid]
    return node_id_to_name

  def load_chinese_map(self, uid_chinese_lookup_path):
    """Loads a '<node id>\\t<label>' file into an {int node id: label} dict.

    Blank lines are ignored.

    Raises:
      ValueError: If a non-blank line is not '<digits>\\t<label>'.
    """
    with tf.gfile.GFile(uid_chinese_lookup_path) as f:
      return self._parse_chinese_lines(f.readlines())

  def id_to_string(self, node_id):
    """Returns the label for `node_id`, or '' when the ID is unknown."""
    return self.node_lookup.get(node_id, '')

# Create a graph to hold the model trained by Google.
def create_graph(model_dir):
  """Imports the saved GraphDef file into the current default graph.

  Args:
    model_dir: Directory containing 'classify_image_graph_def.pb'.

  Returns:
    Nothing; the nodes are imported with an empty name prefix.
  """
  graph_path = os.path.join(model_dir, 'classify_image_graph_def.pb')
  # GFile replaces the deprecated FastGFile alias.
  with tf.gfile.GFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')


def run_inference_on_image(
    image,
    uid_chinese_lookup_path='./data/imagenet_2012_challenge_label_chinese_map.pbtxt'):
  """Runs inference on an image and prints the top predictions.

  Prints one '<label> (score = <probability>)' line for each of the
  `FLAGS.num_top_predictions` highest-scoring classes, best first.

  Args:
    image: Image file name.
    uid_chinese_lookup_path: Path of the node ID --> label file handed to
      NodeLookup.

  Returns:
    Nothing
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  # Read through a context manager so the file handle is always closed.
  with tf.gfile.GFile(image, 'rb') as f:
    image_data = f.read()

  # Creates graph from saved GraphDef.
  create_graph(FLAGS.model_dir)

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
  predictions = np.squeeze(predictions)

  # Creates node ID --> label string lookup.
  node_lookup = NodeLookup(uid_chinese_lookup_path=uid_chinese_lookup_path,
                           model_dir=FLAGS.model_dir)

  # Sort descending first, then truncate: slicing with [-k:] would return
  # every class when k == 0 and a meaningless range when k < 0.
  num_top = max(FLAGS.num_top_predictions, 0)
  top_k = predictions.argsort()[::-1][:num_top]
  for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

def _safe_tar_members(archive, dest_directory):
  """Returns the archive's members after checking none escapes the target.

  Args:
    archive: An open `tarfile.TarFile`.
    dest_directory: Directory the archive is about to be extracted into.

  Returns:
    The list from `archive.getmembers()`.

  Raises:
    ValueError: If a member path, or the target of a link member, resolves
      to a location outside `dest_directory`.
  """
  root = os.path.realpath(dest_directory)
  prefix = root if root.endswith(os.sep) else root + os.sep

  def _inside(path):
    resolved = os.path.realpath(path)
    return resolved == root or resolved.startswith(prefix)

  members = archive.getmembers()
  for member in members:
    target = os.path.join(root, member.name)
    if not _inside(target):
      raise ValueError('Refusing to extract %r outside of %r'
                       % (member.name, dest_directory))
    if member.issym():
      # Symlink targets are relative to the directory holding the link.
      link_target = os.path.join(os.path.dirname(target), member.linkname)
    elif member.islnk():
      # Hard-link targets are relative to the archive root.
      link_target = os.path.join(root, member.linkname)
    else:
      continue
    if not _inside(link_target):
      raise ValueError('Refusing to extract link %r pointing outside of %r'
                       % (member.name, dest_directory))
  return members


# Check whether the model archive already exists locally.
def maybe_download_and_extract():
  """Download and extract model tar file.

  The archive named by the last path component of DATA_URL is stored in
  `FLAGS.model_dir` (created when missing). It is downloaded only when it is
  not already there, and it is extracted into that directory on every call.

  Raises:
    ValueError: If the archive contains an entry that would be written
      outside `FLAGS.model_dir`.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  # Take the file name from the URL and place it in the model directory.
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    # Callback that reports download progress on stdout.
    def _progress(count, block_size, total_size):
      received = count * block_size
      if total_size > 0:
        percent = min(float(received) / float(total_size) * 100.0, 100.0)
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
      else:
        # Size unknown (reported as 0 or negative): avoid dividing by it.
        sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, received))
      sys.stdout.flush()

    # Download to a temporary name so an interrupted transfer never leaves a
    # truncated archive that later runs would mistake for a complete one.
    partial_path = filepath + '.part'
    try:
      urllib.request.urlretrieve(DATA_URL, partial_path, _progress)
      os.rename(partial_path, filepath)
    finally:
      if os.path.exists(partial_path):
        os.remove(partial_path)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  # Extract (also when the archive was already present), closing the archive
  # afterwards and rejecting entries that would land outside the directory.
  with tarfile.open(filepath, 'r:gz') as archive:
    archive.extractall(
        dest_directory, members=_safe_tar_members(archive, dest_directory))


def main(_):
  """Fetches the model if needed, then classifies the selected image.

  The image is `FLAGS.image_file` when that flag is non-empty; otherwise
  '2.png' inside `FLAGS.model_dir` is used. The positional argument is
  ignored.
  """
  maybe_download_and_extract()

  image = FLAGS.image_file
  if not image:
    image = os.path.join(FLAGS.model_dir, '2.png')
  run_inference_on_image(image)


if __name__ == '__main__':
  # Files expected in --model_dir:
  # classify_image_graph_def.pb:
  #   Binary representation of the GraphDef protocol buffer.
  # imagenet_synset_to_human_label_map.txt:
  #   Map from synset ID to a human readable string.
  # imagenet_2012_challenge_label_map_proto.pbtxt:
  #   Text representation of a protocol buffer mapping a label to synset ID.

  # Command-line flags, registered in this order: (flag, add_argument kwargs).
  arg_specs = (
      ('--model_dir', dict(
          type=str,
          default='img',
          help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """)),
      ('--image_file', dict(
          type=str,
          default='',
          help='Absolute path to image file.')),
      ('--num_top_predictions', dict(
          type=int,
          default=5,
          help='Display this many predictions.')),
  )
  parser = argparse.ArgumentParser()
  for flag, options in arg_specs:
    parser.add_argument(flag, **options)

  # Parse the command line: arguments registered on the parser end up in
  # FLAGS, anything unrecognized is returned in `unparsed` and forwarded to
  # tf.app.run together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


# Notebook cell output: UnboundLocalError: local variable 'image' referenced before assignment