New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TensorFlow image_processing component #17795
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
0cae593
initial tensorflow image_processing component
hunterjm cf1472b
linting fixes
hunterjm b694366
make displayed attribute a summary of objects
hunterjm bf882de
fix missed merge conflict and add warning supression back in for CPU …
hunterjm 6005ddd
restructure tensorflow component to install on the fly, remove from D…
hunterjm d7272cf
add both matches and summary as attributes
hunterjm f441355
address review comments
hunterjm 1ee32ed
do not use deps folder as default, as it should only be managed by HA…
hunterjm File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
347 changes: 347 additions & 0 deletions
347
homeassistant/components/image_processing/tensorflow.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
""" | ||
Component that performs TensorFlow classification on images. | ||
|
||
For a quick start, pick a pre-trained COCO model from: | ||
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md | ||
|
||
For more details about this platform, please refer to the documentation at | ||
https://home-assistant.io/components/image_processing.tensorflow/ | ||
""" | ||
import logging | ||
import sys | ||
import os | ||
|
||
import voluptuous as vol | ||
|
||
from homeassistant.components.image_processing import ( | ||
CONF_CONFIDENCE, CONF_ENTITY_ID, CONF_NAME, CONF_SOURCE, PLATFORM_SCHEMA, | ||
ImageProcessingEntity) | ||
from homeassistant.core import split_entity_id | ||
from homeassistant.helpers import template | ||
import homeassistant.helpers.config_validation as cv | ||
|
||
REQUIREMENTS = ['numpy==1.15.3', 'pillow==5.2.0', | ||
'protobuf==3.6.1', 'tensorflow==1.11.0'] | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
ATTR_MATCHES = 'matches' | ||
ATTR_SUMMARY = 'summary' | ||
ATTR_TOTAL_MATCHES = 'total_matches' | ||
|
||
CONF_FILE_OUT = 'file_out' | ||
CONF_MODEL = 'model' | ||
CONF_GRAPH = 'graph' | ||
CONF_LABELS = 'labels' | ||
CONF_MODEL_DIR = 'model_dir' | ||
CONF_CATEGORIES = 'categories' | ||
CONF_CATEGORY = 'category' | ||
CONF_AREA = 'area' | ||
CONF_TOP = 'top' | ||
CONF_LEFT = 'left' | ||
CONF_BOTTOM = 'bottom' | ||
CONF_RIGHT = 'right' | ||
|
||
AREA_SCHEMA = vol.Schema({ | ||
vol.Optional(CONF_TOP, default=0): cv.small_float, | ||
vol.Optional(CONF_LEFT, default=0): cv.small_float, | ||
vol.Optional(CONF_BOTTOM, default=1): cv.small_float, | ||
vol.Optional(CONF_RIGHT, default=1): cv.small_float | ||
}) | ||
|
||
CATEGORY_SCHEMA = vol.Schema({ | ||
vol.Required(CONF_CATEGORY): cv.string, | ||
vol.Optional(CONF_AREA): AREA_SCHEMA | ||
}) | ||
|
||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ | ||
vol.Optional(CONF_FILE_OUT, default=[]): | ||
vol.All(cv.ensure_list, [cv.template]), | ||
vol.Required(CONF_MODEL): vol.Schema({ | ||
vol.Required(CONF_GRAPH): cv.isfile, | ||
vol.Optional(CONF_LABELS): cv.isfile, | ||
vol.Optional(CONF_MODEL_DIR): cv.isdir, | ||
vol.Optional(CONF_AREA): AREA_SCHEMA, | ||
vol.Optional(CONF_CATEGORIES, default=[]): | ||
vol.All(cv.ensure_list, [vol.Any( | ||
cv.string, | ||
CATEGORY_SCHEMA | ||
)]) | ||
}) | ||
}) | ||
|
||
|
||
def draw_box(draw, box, img_width, | ||
img_height, text='', color=(255, 255, 0)): | ||
"""Draw bounding box on image.""" | ||
ymin, xmin, ymax, xmax = box | ||
(left, right, top, bottom) = (xmin * img_width, xmax * img_width, | ||
ymin * img_height, ymax * img_height) | ||
draw.line([(left, top), (left, bottom), (right, bottom), | ||
(right, top), (left, top)], width=5, fill=color) | ||
if text: | ||
draw.text((left, abs(top-15)), text, fill=color) | ||
|
||
|
||
def setup_platform(hass, config, add_entities, discovery_info=None): | ||
"""Set up the TensorFlow image processing platform.""" | ||
model_config = config.get(CONF_MODEL) | ||
model_dir = model_config.get(CONF_MODEL_DIR) \ | ||
or hass.config.path('tensorflow') | ||
labels = model_config.get(CONF_LABELS) \ | ||
or hass.config.path('tensorflow', 'object_detection', | ||
'data', 'mscoco_label_map.pbtxt') | ||
|
||
# Make sure locations exist | ||
if not os.path.isdir(model_dir) or not os.path.exists(labels): | ||
_LOGGER.error("Unable to locate tensorflow models or label map.") | ||
return | ||
|
||
# append custom model path to sys.path | ||
sys.path.append(model_dir) | ||
|
||
try: | ||
# Verify that the TensorFlow Object Detection API is pre-installed | ||
# pylint: disable=unused-import,unused-variable | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
import tensorflow as tf # noqa | ||
from object_detection.utils import label_map_util # noqa | ||
except ImportError: | ||
# pylint: disable=line-too-long | ||
_LOGGER.error( | ||
"No TensorFlow Object Detection library found! Install or compile " | ||
"for your system following instructions here: " | ||
"https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md") # noqa | ||
return | ||
|
||
try: | ||
# Display warning that PIL will be used if no OpenCV is found. | ||
# pylint: disable=unused-import,unused-variable | ||
import cv2 # noqa | ||
except ImportError: | ||
_LOGGER.warning("No OpenCV library found. " | ||
"TensorFlow will process image with " | ||
"PIL at reduced resolution.") | ||
|
||
# setup tensorflow graph, session, and label map to pass to processor | ||
# pylint: disable=no-member | ||
detection_graph = tf.Graph() | ||
with detection_graph.as_default(): | ||
od_graph_def = tf.GraphDef() | ||
with tf.gfile.GFile(model_config.get(CONF_GRAPH), 'rb') as fid: | ||
serialized_graph = fid.read() | ||
od_graph_def.ParseFromString(serialized_graph) | ||
tf.import_graph_def(od_graph_def, name='') | ||
|
||
session = tf.Session(graph=detection_graph) | ||
label_map = label_map_util.load_labelmap(labels) | ||
categories = label_map_util.convert_label_map_to_categories( | ||
label_map, max_num_classes=90, use_display_name=True) | ||
category_index = label_map_util.create_category_index(categories) | ||
|
||
entities = [] | ||
|
||
for camera in config[CONF_SOURCE]: | ||
entities.append(TensorFlowImageProcessor( | ||
hass, camera[CONF_ENTITY_ID], camera.get(CONF_NAME), | ||
session, detection_graph, category_index, config)) | ||
|
||
add_entities(entities) | ||
|
||
|
||
class TensorFlowImageProcessor(ImageProcessingEntity): | ||
"""Representation of an TensorFlow image processor.""" | ||
|
||
def __init__(self, hass, camera_entity, name, session, detection_graph, | ||
category_index, config): | ||
"""Initialize the TensorFlow entity.""" | ||
model_config = config.get(CONF_MODEL) | ||
self.hass = hass | ||
self._camera_entity = camera_entity | ||
if name: | ||
self._name = name | ||
else: | ||
self._name = "TensorFlow {0}".format( | ||
split_entity_id(camera_entity)[1]) | ||
self._session = session | ||
self._graph = detection_graph | ||
self._category_index = category_index | ||
self._min_confidence = config.get(CONF_CONFIDENCE) | ||
self._file_out = config.get(CONF_FILE_OUT) | ||
|
||
# handle categories and specific detection areas | ||
categories = model_config.get(CONF_CATEGORIES) | ||
self._include_categories = [] | ||
self._category_areas = {} | ||
for category in categories: | ||
if isinstance(category, dict): | ||
category_name = category.get(CONF_CATEGORY) | ||
category_area = category.get(CONF_AREA) | ||
self._include_categories.append(category_name) | ||
self._category_areas[category_name] = [0, 0, 1, 1] | ||
if category_area: | ||
self._category_areas[category_name] = [ | ||
category_area.get(CONF_TOP), | ||
category_area.get(CONF_LEFT), | ||
category_area.get(CONF_BOTTOM), | ||
category_area.get(CONF_RIGHT) | ||
] | ||
else: | ||
self._include_categories.append(category) | ||
self._category_areas[category] = [0, 0, 1, 1] | ||
|
||
# Handle global detection area | ||
self._area = [0, 0, 1, 1] | ||
area_config = model_config.get(CONF_AREA) | ||
if area_config: | ||
self._area = [ | ||
area_config.get(CONF_TOP), | ||
area_config.get(CONF_LEFT), | ||
area_config.get(CONF_BOTTOM), | ||
area_config.get(CONF_RIGHT) | ||
] | ||
|
||
template.attach(hass, self._file_out) | ||
|
||
self._matches = {} | ||
self._total_matches = 0 | ||
self._last_image = None | ||
|
||
@property | ||
def camera_entity(self): | ||
"""Return camera entity id from process pictures.""" | ||
return self._camera_entity | ||
|
||
@property | ||
def name(self): | ||
"""Return the name of the image processor.""" | ||
return self._name | ||
|
||
@property | ||
def state(self): | ||
"""Return the state of the entity.""" | ||
return self._total_matches | ||
|
||
@property | ||
def device_state_attributes(self): | ||
"""Return device specific state attributes.""" | ||
return { | ||
ATTR_MATCHES: self._matches, | ||
ATTR_SUMMARY: {category: len(values) | ||
for category, values in self._matches.items()}, | ||
ATTR_TOTAL_MATCHES: self._total_matches | ||
} | ||
|
||
def _save_image(self, image, matches, paths): | ||
from PIL import Image, ImageDraw | ||
import io | ||
img = Image.open(io.BytesIO(bytearray(image))).convert('RGB') | ||
img_width, img_height = img.size | ||
draw = ImageDraw.Draw(img) | ||
|
||
# Draw custom global region/area | ||
if self._area != [0, 0, 1, 1]: | ||
draw_box(draw, self._area, | ||
img_width, img_height, | ||
"Detection Area", (0, 255, 255)) | ||
|
||
for category, values in matches.items(): | ||
# Draw custom category regions/areas | ||
if self._category_areas[category] != [0, 0, 1, 1]: | ||
label = "{} Detection Area".format(category.capitalize()) | ||
draw_box(draw, self._category_areas[category], img_width, | ||
img_height, label, (0, 255, 0)) | ||
|
||
# Draw detected objects | ||
for instance in values: | ||
label = "{0} {1:.1f}%".format(category, instance['score']) | ||
draw_box(draw, instance['box'], | ||
img_width, img_height, | ||
label, (255, 255, 0)) | ||
|
||
for path in paths: | ||
_LOGGER.info("Saving results image to %s", path) | ||
img.save(path) | ||
|
||
def process_image(self, image): | ||
"""Process the image.""" | ||
import numpy as np | ||
|
||
try: | ||
import cv2 # pylint: disable=import-error | ||
img = cv2.imdecode( | ||
np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED) | ||
inp = img[:, :, [2, 1, 0]] # BGR->RGB | ||
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3) | ||
except ImportError: | ||
from PIL import Image | ||
import io | ||
img = Image.open(io.BytesIO(bytearray(image))).convert('RGB') | ||
img.thumbnail((460, 460), Image.ANTIALIAS) | ||
img_width, img_height = img.size | ||
inp = np.array(img.getdata()).reshape( | ||
(img_height, img_width, 3)).astype(np.uint8) | ||
inp_expanded = np.expand_dims(inp, axis=0) | ||
|
||
image_tensor = self._graph.get_tensor_by_name('image_tensor:0') | ||
boxes = self._graph.get_tensor_by_name('detection_boxes:0') | ||
scores = self._graph.get_tensor_by_name('detection_scores:0') | ||
classes = self._graph.get_tensor_by_name('detection_classes:0') | ||
boxes, scores, classes = self._session.run( | ||
[boxes, scores, classes], | ||
feed_dict={image_tensor: inp_expanded}) | ||
boxes, scores, classes = map(np.squeeze, [boxes, scores, classes]) | ||
classes = classes.astype(int) | ||
|
||
matches = {} | ||
total_matches = 0 | ||
for box, score, obj_class in zip(boxes, scores, classes): | ||
score = score * 100 | ||
boxes = box.tolist() | ||
|
||
# Exclude matches below min confidence value | ||
if score < self._min_confidence: | ||
continue | ||
|
||
# Exclude matches outside global area definition | ||
if (boxes[0] < self._area[0] or boxes[1] < self._area[1] | ||
or boxes[2] > self._area[2] or boxes[3] > self._area[3]): | ||
continue | ||
|
||
category = self._category_index[obj_class]['name'] | ||
|
||
# Exclude unlisted categories | ||
if (self._include_categories | ||
and category not in self._include_categories): | ||
continue | ||
|
||
# Exclude matches outside category specific area definition | ||
if (self._category_areas | ||
and (boxes[0] < self._category_areas[category][0] | ||
or boxes[1] < self._category_areas[category][1] | ||
or boxes[2] > self._category_areas[category][2] | ||
or boxes[3] > self._category_areas[category][3])): | ||
continue | ||
|
||
# If we got here, we should include it | ||
if category not in matches.keys(): | ||
matches[category] = [] | ||
matches[category].append({ | ||
'score': float(score), | ||
'box': boxes | ||
}) | ||
total_matches += 1 | ||
|
||
# Save Images | ||
if total_matches and self._file_out: | ||
paths = [] | ||
for path_template in self._file_out: | ||
if isinstance(path_template, template.Template): | ||
paths.append(path_template.render( | ||
camera_entity=self._camera_entity)) | ||
else: | ||
paths.append(path_template) | ||
self._save_image(image, matches, paths) | ||
|
||
self._matches = matches | ||
self._total_matches = total_matches |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ehhhrrrrr
We should be able to specify the paths to look to tensorflow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some requirements that can not be installed via pip. The object detection portion of the library is not included in the standard TensorFlow wheel. This is to include the protobuf models and util functions created as a part of the tensorflow install shell script.