Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorFlow image_processing component #17795

Merged
merged 8 commits into from Nov 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .coveragerc
Expand Up @@ -510,6 +510,7 @@ omit =
homeassistant/components/image_processing/dlib_face_detect.py
homeassistant/components/image_processing/dlib_face_identify.py
homeassistant/components/image_processing/seven_segments.py
homeassistant/components/image_processing/tensorflow.py
homeassistant/components/keyboard_remote.py
homeassistant/components/keyboard.py
homeassistant/components/light/avion.py
Expand Down
347 changes: 347 additions & 0 deletions homeassistant/components/image_processing/tensorflow.py
@@ -0,0 +1,347 @@
"""
Component that performs TensorFlow classification on images.

For a quick start, pick a pre-trained COCO model from:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/image_processing.tensorflow/
"""
import logging
import sys
import os

import voluptuous as vol

from homeassistant.components.image_processing import (
CONF_CONFIDENCE, CONF_ENTITY_ID, CONF_NAME, CONF_SOURCE, PLATFORM_SCHEMA,
ImageProcessingEntity)
from homeassistant.core import split_entity_id
from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv

# Pinned requirements: the tensorflow wheel version must match the
# separately installed Object Detection API (see module docstring).
REQUIREMENTS = ['numpy==1.15.3', 'pillow==5.2.0',
                'protobuf==3.6.1', 'tensorflow==1.11.0']

_LOGGER = logging.getLogger(__name__)

# Entity state attribute keys.
ATTR_MATCHES = 'matches'
ATTR_SUMMARY = 'summary'
ATTR_TOTAL_MATCHES = 'total_matches'

# Configuration keys.
CONF_FILE_OUT = 'file_out'
CONF_MODEL = 'model'
CONF_GRAPH = 'graph'
CONF_LABELS = 'labels'
CONF_MODEL_DIR = 'model_dir'
CONF_CATEGORIES = 'categories'
CONF_CATEGORY = 'category'
CONF_AREA = 'area'
CONF_TOP = 'top'
CONF_LEFT = 'left'
CONF_BOTTOM = 'bottom'
CONF_RIGHT = 'right'

# A detection area in normalized [0, 1] image coordinates; the defaults
# cover the whole frame.
AREA_SCHEMA = vol.Schema({
    vol.Optional(CONF_TOP, default=0): cv.small_float,
    vol.Optional(CONF_LEFT, default=0): cv.small_float,
    vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
    vol.Optional(CONF_RIGHT, default=1): cv.small_float
})

# A category entry with an optional per-category detection area.
CATEGORY_SCHEMA = vol.Schema({
    vol.Required(CONF_CATEGORY): cv.string,
    vol.Optional(CONF_AREA): AREA_SCHEMA
})

# Platform config: categories may be given either as bare strings or as
# CATEGORY_SCHEMA mappings.
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
    vol.Optional(CONF_FILE_OUT, default=[]):
        vol.All(cv.ensure_list, [cv.template]),
    vol.Required(CONF_MODEL): vol.Schema({
        vol.Required(CONF_GRAPH): cv.isfile,
        vol.Optional(CONF_LABELS): cv.isfile,
        vol.Optional(CONF_MODEL_DIR): cv.isdir,
        vol.Optional(CONF_AREA): AREA_SCHEMA,
        vol.Optional(CONF_CATEGORIES, default=[]):
            vol.All(cv.ensure_list, [vol.Any(
                cv.string,
                CATEGORY_SCHEMA
            )])
    })
})


def draw_box(draw, box, img_width,
             img_height, text='', color=(255, 255, 0)):
    """Draw a bounding box, with an optional label, onto *draw*.

    The box is a normalized (ymin, xmin, ymax, xmax) tuple that is
    scaled to pixel coordinates using the image dimensions.
    """
    y_min, x_min, y_max, x_max = box
    left = x_min * img_width
    right = x_max * img_width
    top = y_min * img_height
    bottom = y_max * img_height
    outline = [(left, top), (left, bottom), (right, bottom),
               (right, top), (left, top)]
    draw.line(outline, width=5, fill=color)
    if text:
        # Label sits just above the box; abs() keeps it on-canvas when
        # the box touches the top edge of the image.
        draw.text((left, abs(top-15)), text, fill=color)


def setup_platform(hass, config, add_entities, discovery_info=None):
    """Set up the TensorFlow image processing platform.

    Validates the configured model/label paths, verifies that the
    TensorFlow Object Detection API is pre-installed, loads the frozen
    inference graph once, and creates one processor entity per camera
    source. Logs an error and returns without adding entities if the
    model directory, label map, or Object Detection library is missing.
    """
    model_config = config.get(CONF_MODEL)
    model_dir = model_config.get(CONF_MODEL_DIR) \
        or hass.config.path('tensorflow')
    labels = model_config.get(CONF_LABELS) \
        or hass.config.path('tensorflow', 'object_detection',
                            'data', 'mscoco_label_map.pbtxt')

    # Make sure locations exist
    if not os.path.isdir(model_dir) or not os.path.exists(labels):
        _LOGGER.error("Unable to locate tensorflow models or label map.")
        return

    # Append the custom model path to sys.path: the object_detection
    # package is not part of the standard TensorFlow wheel and is
    # expected to have been installed into model_dir by the user.
    sys.path.append(model_dir)

    try:
        # Verify that the TensorFlow Object Detection API is pre-installed
        # pylint: disable=unused-import,unused-variable
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        import tensorflow as tf  # noqa
        from object_detection.utils import label_map_util  # noqa
    except ImportError:
        # pylint: disable=line-too-long
        _LOGGER.error(
            "No TensorFlow Object Detection library found! Install or compile "
            "for your system following instructions here: "
            "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md")  # noqa
        return

    try:
        # Display warning that PIL will be used if no OpenCV is found.
        # pylint: disable=unused-import,unused-variable
        import cv2  # noqa
    except ImportError:
        _LOGGER.warning("No OpenCV library found. "
                        "TensorFlow will process image with "
                        "PIL at reduced resolution.")

    # Set up tensorflow graph, session, and label map to pass to processor.
    # The graph and session are shared by all entities created below.
    # pylint: disable=no-member
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_config.get(CONF_GRAPH), 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    session = tf.Session(graph=detection_graph)
    label_map = label_map_util.load_labelmap(labels)
    # COCO models ship up to 90 classes; use_display_name yields
    # human-readable category names such as 'person'.
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=90, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    entities = []

    for camera in config[CONF_SOURCE]:
        entities.append(TensorFlowImageProcessor(
            hass, camera[CONF_ENTITY_ID], camera.get(CONF_NAME),
            session, detection_graph, category_index, config))

    add_entities(entities)


class TensorFlowImageProcessor(ImageProcessingEntity):
    """Representation of a TensorFlow image processor."""

    def __init__(self, hass, camera_entity, name, session, detection_graph,
                 category_index, config):
        """Initialize the TensorFlow entity."""
        model_config = config.get(CONF_MODEL)
        self.hass = hass
        self._camera_entity = camera_entity
        if name:
            self._name = name
        else:
            # Derive a default name from the source camera's object id.
            self._name = "TensorFlow {0}".format(
                split_entity_id(camera_entity)[1])
        self._session = session
        self._graph = detection_graph
        self._category_index = category_index
        self._min_confidence = config.get(CONF_CONFIDENCE)
        self._file_out = config.get(CONF_FILE_OUT)

        # Handle categories and specific detection areas.
        # _category_areas maps category name -> [top, left, bottom, right]
        # in normalized image coordinates. It only holds entries for
        # explicitly configured categories.
        categories = model_config.get(CONF_CATEGORIES)
        self._include_categories = []
        self._category_areas = {}
        for category in categories:
            if isinstance(category, dict):
                category_name = category.get(CONF_CATEGORY)
                category_area = category.get(CONF_AREA)
                self._include_categories.append(category_name)
                self._category_areas[category_name] = [0, 0, 1, 1]
                if category_area:
                    self._category_areas[category_name] = [
                        category_area.get(CONF_TOP),
                        category_area.get(CONF_LEFT),
                        category_area.get(CONF_BOTTOM),
                        category_area.get(CONF_RIGHT)
                    ]
            else:
                self._include_categories.append(category)
                self._category_areas[category] = [0, 0, 1, 1]

        # Handle global detection area (defaults to the whole frame).
        self._area = [0, 0, 1, 1]
        area_config = model_config.get(CONF_AREA)
        if area_config:
            self._area = [
                area_config.get(CONF_TOP),
                area_config.get(CONF_LEFT),
                area_config.get(CONF_BOTTOM),
                area_config.get(CONF_RIGHT)
            ]

        template.attach(hass, self._file_out)

        self._matches = {}
        self._total_matches = 0
        self._last_image = None

    @property
    def camera_entity(self):
        """Return camera entity id from process pictures."""
        return self._camera_entity

    @property
    def name(self):
        """Return the name of the image processor."""
        return self._name

    @property
    def state(self):
        """Return the state of the entity (total number of matches)."""
        return self._total_matches

    @property
    def device_state_attributes(self):
        """Return device specific state attributes."""
        return {
            ATTR_MATCHES: self._matches,
            ATTR_SUMMARY: {category: len(values)
                           for category, values in self._matches.items()},
            ATTR_TOTAL_MATCHES: self._total_matches
        }

    def _save_image(self, image, matches, paths):
        """Annotate *image* with detection boxes and write it to *paths*."""
        from PIL import Image, ImageDraw
        import io
        img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
        img_width, img_height = img.size
        draw = ImageDraw.Draw(img)

        # Draw custom global region/area
        if self._area != [0, 0, 1, 1]:
            draw_box(draw, self._area,
                     img_width, img_height,
                     "Detection Area", (0, 255, 255))

        for category, values in matches.items():
            # Draw custom category regions/areas. Use .get() because a
            # match may belong to a category that was never explicitly
            # configured (when no categories are listed, every detected
            # category passes the filter but has no _category_areas
            # entry), which would otherwise raise KeyError here.
            category_area = self._category_areas.get(category, [0, 0, 1, 1])
            if category_area != [0, 0, 1, 1]:
                label = "{} Detection Area".format(category.capitalize())
                draw_box(draw, category_area, img_width,
                         img_height, label, (0, 255, 0))

            # Draw detected objects
            for instance in values:
                label = "{0} {1:.1f}%".format(category, instance['score'])
                draw_box(draw, instance['box'],
                         img_width, img_height,
                         label, (255, 255, 0))

        for path in paths:
            _LOGGER.info("Saving results image to %s", path)
            img.save(path)

    def process_image(self, image):
        """Run object detection on *image* and update matches/state.

        Decodes the image (OpenCV if available, otherwise PIL at reduced
        resolution), feeds it through the detection graph, filters the
        results by confidence, global area, category list and category
        areas, and optionally saves an annotated copy to file_out paths.
        """
        import numpy as np

        try:
            import cv2  # pylint: disable=import-error
            img = cv2.imdecode(
                np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
            inp = img[:, :, [2, 1, 0]]  # BGR->RGB
            inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
        except ImportError:
            # Fall back to PIL at reduced resolution when OpenCV is absent.
            from PIL import Image
            import io
            img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
            img.thumbnail((460, 460), Image.ANTIALIAS)
            img_width, img_height = img.size
            inp = np.array(img.getdata()).reshape(
                (img_height, img_width, 3)).astype(np.uint8)
            inp_expanded = np.expand_dims(inp, axis=0)

        image_tensor = self._graph.get_tensor_by_name('image_tensor:0')
        boxes = self._graph.get_tensor_by_name('detection_boxes:0')
        scores = self._graph.get_tensor_by_name('detection_scores:0')
        classes = self._graph.get_tensor_by_name('detection_classes:0')
        boxes, scores, classes = self._session.run(
            [boxes, scores, classes],
            feed_dict={image_tensor: inp_expanded})
        boxes, scores, classes = map(np.squeeze, [boxes, scores, classes])
        classes = classes.astype(int)

        matches = {}
        total_matches = 0
        for box, score, obj_class in zip(boxes, scores, classes):
            score = score * 100
            boxes = box.tolist()

            # Exclude matches below min confidence value
            if score < self._min_confidence:
                continue

            # Exclude matches outside global area definition
            if (boxes[0] < self._area[0] or boxes[1] < self._area[1]
                    or boxes[2] > self._area[2] or boxes[3] > self._area[3]):
                continue

            category = self._category_index[obj_class]['name']

            # Exclude unlisted categories
            if (self._include_categories
                    and category not in self._include_categories):
                continue

            # Exclude matches outside category specific area definition.
            # The dict lookup is safe here: when _category_areas is
            # non-empty, category passed the include filter above and
            # therefore has an entry.
            if (self._category_areas
                    and (boxes[0] < self._category_areas[category][0]
                         or boxes[1] < self._category_areas[category][1]
                         or boxes[2] > self._category_areas[category][2]
                         or boxes[3] > self._category_areas[category][3])):
                continue

            # If we got here, we should include it
            matches.setdefault(category, []).append({
                'score': float(score),
                'box': boxes
            })
            total_matches += 1

        # Save Images
        if total_matches and self._file_out:
            paths = []
            for path_template in self._file_out:
                if isinstance(path_template, template.Template):
                    paths.append(path_template.render(
                        camera_entity=self._camera_entity))
                else:
                    paths.append(path_template)
            self._save_image(image, matches, paths)

        self._matches = matches
        self._total_matches = total_matches
8 changes: 8 additions & 0 deletions requirements_all.txt
Expand Up @@ -669,6 +669,7 @@ nuheat==0.3.0

# homeassistant.components.binary_sensor.trend
# homeassistant.components.image_processing.opencv
# homeassistant.components.image_processing.tensorflow
# homeassistant.components.sensor.pollen
numpy==1.15.3

Expand Down Expand Up @@ -722,6 +723,7 @@ piglow==1.2.4
pilight==0.1.1

# homeassistant.components.camera.proxy
# homeassistant.components.image_processing.tensorflow
pillow==5.2.0

# homeassistant.components.dominos
Expand All @@ -747,6 +749,9 @@ proliphix==0.4.1
# homeassistant.components.prometheus
prometheus_client==0.2.0

# homeassistant.components.image_processing.tensorflow
protobuf==3.6.1

# homeassistant.components.sensor.systemmonitor
psutil==5.4.8

Expand Down Expand Up @@ -1464,6 +1469,9 @@ temescal==0.1
# homeassistant.components.sensor.temper
temperusb==1.5.3

# homeassistant.components.image_processing.tensorflow
tensorflow==1.11.0

# homeassistant.components.tesla
teslajsonpy==0.0.23

Expand Down
1 change: 1 addition & 0 deletions requirements_test_all.txt
Expand Up @@ -118,6 +118,7 @@ mficlient==0.3.0

# homeassistant.components.binary_sensor.trend
# homeassistant.components.image_processing.opencv
# homeassistant.components.image_processing.tensorflow
# homeassistant.components.sensor.pollen
numpy==1.15.3

Expand Down