In [0]:
# Switch the working directory to the Object Detection API folder so that the
# relative paths used later (label map, downloaded model) resolve correctly.
import os

OBJECT_DETECTION_DIR = os.path.join("models", "research", "object_detection")

# Only change directory when we are not already inside the target folder.
# The path is relative, so re-running this cell without the guard would raise
# because the path no longer resolves from inside the target directory.
_cwd = os.path.abspath(os.getcwd())
if not _cwd.endswith(os.sep + os.path.normpath(OBJECT_DETECTION_DIR)):
    os.chdir(OBJECT_DETECTION_DIR)

In [3]:
# Import the required modules

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf

# Report which TensorFlow version this kernel is running
version_banner = "tensorflow {}".format(tf.__version__)
print(version_banner)

tensorflow 1.10.0


In [0]:
# Make the parent directory importable so that the `object_detection` package
# imports below can be resolved. The guard keeps the cell idempotent:
# re-running it does not append duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from object_detection.utils import ops as utils_ops

# Object Detection API helpers for label maps and for drawing detections
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

In [0]:
# Name of the pre-trained model; the archive name and download URL are derived from it
MODEL_NAME = 'mask_rcnn_inception_v2_coco_2018_01_28'
MODEL_FILE = '{}.tar.gz'.format(MODEL_NAME)
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Location of the frozen inference graph inside the extracted model directory.
# This is the file that is loaded into a tf.Graph for detection.
PATH_TO_FROZEN_GRAPH = '{}/frozen_inference_graph.pb'.format(MODEL_NAME)

# Label map file holding the mapping between class indices and class names
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

# Upper bound on the number of classes read from the label map
NUM_CLASSES = 90

In [0]:
# Download and unpack the pre-trained model.
# The whole step is skipped when the frozen graph is already on disk, so a
# "Restart & Run All" does not re-download the archive every time.
if not os.path.exists(PATH_TO_FROZEN_GRAPH):
    # urlretrieve replaces the deprecated URLopener().retrieve() call
    urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)

    # Extract only the frozen inference graph into the current directory.
    # The context manager guarantees the archive handle is closed.
    extract_root = os.path.realpath(os.getcwd())
    with tarfile.open(MODEL_FILE) as tar_file:
        for member in tar_file.getmembers():
            file_name = os.path.basename(member.name)
            if 'frozen_inference_graph.pb' in file_name:
                # The archive comes from the network: refuse member paths
                # that would be written outside the extraction directory.
                target = os.path.realpath(os.path.join(extract_root, member.name))
                if not target.startswith(extract_root + os.sep):
                    raise ValueError(
                        "Unsafe path in archive: {}".format(member.name))
                tar_file.extract(member, extract_root)

In [0]:
# Load the frozen, pre-trained model into memory.
detection_graph = tf.Graph()

# Read the serialized GraphDef bytes from disk
with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    serialized_graph = fid.read()

# Parse the bytes into a GraphDef protobuf
od_graph_def = tf.GraphDef()
od_graph_def.ParseFromString(serialized_graph)

# Import the parsed definition into detection_graph; name='' keeps the
# original node names without adding a prefix
with detection_graph.as_default():
    tf.import_graph_def(od_graph_def, name='')

In [0]:
# Build the lookup from predicted class indices to category entries.
# The original notes give the example that a prediction of 5 corresponds to
# "airplane" -- NOTE(review): confirm against the label map file.
# The Object Detection API helpers are used here, but per the original notes
# any mapping from index to a suitable string label would do.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)

In [0]:
# Import OpenCV (used for webcam capture and for the preview window)
import cv2

# Open the default web camera (device index 0)
cap = cv2.VideoCapture(0)

try:
    # Fail early with a clear message instead of feeding empty frames to the model
    if not cap.isOpened():
        raise IOError("Unable to open web camera (device index 0)")

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Look up the graph tensors once; they do not change between frames.
            # Input image tensor
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Bounding boxes of the detected objects
            boxes_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Confidence score of each detection (drawn next to the class label)
            scores_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
            # Class index of each detection
            classes_tensor = detection_graph.get_tensor_by_name('detection_classes:0')
            # Number of detections
            num_detections_tensor = detection_graph.get_tensor_by_name('num_detections:0')

            # Process frames until the camera stops delivering them or 'q' is pressed
            while True:
                # Grab one frame from the web camera
                ret, image_np = cap.read()
                if not ret:
                    # No frame was returned (camera unplugged / stream ended);
                    # image_np is not usable in that case, so stop here.
                    break

                # OpenCV delivers frames in BGR channel order; convert to RGB
                # before feeding the detection model.
                image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)

                # Per the original notes the model expects a 4-D input of shape
                # [1, None, None, 3], so add a leading batch dimension.
                image_np_expanded = np.expand_dims(image_rgb, axis=0)

                # Run detection on the frame
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes_tensor, scores_tensor, classes_tensor, num_detections_tensor],
                    feed_dict={image_tensor: image_np_expanded})

                # Draw the detections onto the original (BGR) frame that is displayed
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

                # Show the annotated frame in a window
                cv2.imshow('image', cv2.resize(image_np, (960, 700)))

                # Quit when 'q' is pressed
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    break
finally:
    # Always free the camera and close the preview window, including when an
    # exception is raised or the camera stops delivering frames.
    cap.release()
    cv2.destroyAllWindows()

