-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
555 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from gabrieltool.statemachine.callable_zoo.base import record_kwargs, CallableBase, Null # noqa: F401 | ||
from gabrieltool.statemachine.callable_zoo import processor_zoo # noqa: F401 | ||
from gabrieltool.statemachine.callable_zoo import predicate_zoo # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import inspect | ||
from functools import wraps | ||
|
||
|
||
def record_kwargs(func): | ||
""" | ||
Automatically record constructor arguments | ||
>>> class process: | ||
... @record_kwargs | ||
... def __init__(self, cmd, reachable=False, user='root'): | ||
... pass | ||
>>> p = process('halt', True) | ||
>>> p.cmd, p.reachable, p.user | ||
('halt', True, 'root') | ||
""" | ||
names, varargs, keywords, defaults = inspect.getargspec(func) | ||
if defaults is None: | ||
defaults = () | ||
|
||
@wraps(func) | ||
def wrapper(self, *args, **kargs): | ||
func(self, *args, **kargs) | ||
kwargs = {} | ||
for name, default in zip(reversed(names), reversed(defaults)): | ||
kwargs[name] = default | ||
for name, arg in list(zip(names[1:], args)) + list(kargs.items()): | ||
kwargs[name] = arg | ||
setattr(self, 'kwargs', kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
class CallableBase(): | ||
"""Base class for processor callables. | ||
Callables needs to be able to be serialized and de-serialized. | ||
Arguments: | ||
object {[type]} -- [description] | ||
Returns: | ||
[type] -- [description] | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
setattr(self, 'kwargs', {}) | ||
|
||
@classmethod | ||
def from_json(cls, json_obj): | ||
"""Create a class instance from a json object. | ||
Subclasses should overide this class depending on the input type of | ||
their constructor. | ||
""" | ||
return cls(**json_obj) | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, CallableBase): | ||
return self.kwargs == other.kwargs | ||
return False | ||
|
||
def __ne__(self, other): | ||
return not self.__eq__(other) | ||
|
||
|
||
class Null(CallableBase): | ||
def __call__(self, *args, **kwargs): | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Processing Function on State Machine Inputs. | ||
""" | ||
|
||
from gabrieltool.statemachine.callable_zoo import record_kwargs | ||
from gabrieltool.statemachine.callable_zoo import CallableBase | ||
|
||
|
||
class HasObjectClass(CallableBase): | ||
|
||
@record_kwargs | ||
def __init__(self, class_name): | ||
super().__init__() | ||
self.class_name = class_name | ||
|
||
def __call__(self, app_state): | ||
return self.class_name in app_state | ||
|
||
|
||
class Always(CallableBase): | ||
|
||
def __call__(self, app_state): | ||
return True |
3 changes: 3 additions & 0 deletions
3
gabrieltool/statemachine/callable_zoo/processor_zoo/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import DummyCallable, FasterRCNNOpenCVCallable # noqa: F401 | ||
from .containerized import FasterRCNNContainerCallable # noqa: F401 | ||
from .containerized import TFServingContainerCallable # noqa: F401 |
148 changes: 148 additions & 0 deletions
148
gabrieltool/statemachine/callable_zoo/processor_zoo/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
|
||
# -*- coding: utf-8 -*- | ||
"""Abstract base classes for processors | ||
""" | ||
import copy | ||
|
||
import cv2 | ||
import numpy as np | ||
from logzero import logger | ||
|
||
from gabrieltool.statemachine.callable_zoo import record_kwargs | ||
from gabrieltool.statemachine.callable_zoo import CallableBase | ||
|
||
|
||
def visualize_detections(img, results): | ||
"""Visualize object detection outputs. | ||
This is a helper function for debugging processor callables. | ||
The results should follow Gabrieltool's convention, which is | ||
Arguments: | ||
img {OpenCV Image} | ||
results {Dictionary} -- a dictionary of class_idx -> [[x1, y1, x2, y2, confidence, cls_idx],...] | ||
Returns: | ||
OpenCV Image -- Image with detected objects annotated | ||
""" | ||
img_detections = img.copy() | ||
for _, dets in results.items(): | ||
for i in range(len(dets)): | ||
cls_name = str(dets[i][-1]) | ||
bbox = dets[i][:4] | ||
score = dets[i][-2] | ||
text = "%s : %f" % (cls_name, score) | ||
cv2.rectangle(img_detections, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 0, 255), 8) | ||
cv2.putText(img_detections, text, (int(bbox[0]), int(bbox[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2) | ||
return img_detections | ||
|
||
|
||
class DummyCallable(CallableBase): | ||
|
||
@record_kwargs | ||
def __init__(self, dummy_input='dummy_input_value'): | ||
super(DummyCallable, self).__init__() | ||
|
||
def __call__(self, image, debug=False): | ||
return {'dummy_key': 'dummy_value'} | ||
|
||
|
||
class FasterRCNNOpenCVCallable(CallableBase): | ||
|
||
@record_kwargs | ||
def __init__(self, proto_path, model_path, labels=None, conf_threshold=0.8): | ||
# For default parameter settings, | ||
# see: | ||
# https://github.com/rbgirshick/fast-rcnn/blob/b612190f279da3c11dd8b1396dd5e72779f8e463/lib/fast_rcnn/config.py | ||
super(FasterRCNNOpenCVCallable, self).__init__() | ||
self._scale = 600 | ||
self._max_size = 1000 | ||
# Pixel mean values (BGR order) as a (1, 1, 3) array | ||
# We use the same pixel mean for all networks even though it's not exactly what | ||
# they were trained with | ||
self._pixel_means = [102.9801, 115.9465, 122.7717] | ||
self._nms_threshold = 0.3 | ||
self._labels = labels | ||
self._net = cv2.dnn.readNetFromCaffe(proto_path, model_path) | ||
self._conf_threshold = conf_threshold | ||
logger.debug( | ||
'Created a FasterRCNNOpenCVProcessor:\nDNN proto definition is at {}\n' | ||
'model weight is at {}\nlabels are {}\nconf_threshold is {}'.format( | ||
proto_path, model_path, self._labels, self._conf_threshold)) | ||
|
||
@classmethod | ||
def from_json(cls, json_obj): | ||
try: | ||
kwargs = copy.copy(json_obj) | ||
kwargs['labels'] = json_obj['labels'] | ||
kwargs['_conf_threshold'] = float(json_obj['conf_threshold']) | ||
except ValueError as e: | ||
raise ValueError( | ||
'Failed to convert json object to {} instance. ' | ||
'The input json object is {}. ({})'.format(cls.__name__, | ||
json_obj, e)) | ||
return cls(**json_obj) | ||
|
||
def _getOutputsNames(self, net): | ||
layersNames = net.getLayerNames() | ||
return [layersNames[i[0] - 1] for i in net.getUnconnectedOutLayers()] | ||
|
||
def __call__(self, image): | ||
height, width = image.shape[:2] | ||
|
||
# resize image to correct size | ||
im_size_min = np.min(image.shape[0:2]) | ||
im_size_max = np.max(image.shape[0:2]) | ||
im_scale = float(self._scale) / float(im_size_min) | ||
# Prevent the biggest axis from being more than MAX_SIZE | ||
if np.round(im_scale * im_size_max) > self._max_size: | ||
im_scale = float(self._max_size) / float(im_size_max) | ||
im = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, | ||
interpolation=cv2.INTER_LINEAR) | ||
# create input data | ||
blob = cv2.dnn.blobFromImage(im, 1, (width, height), self._pixel_means, | ||
swapRB=False, crop=False) | ||
imInfo = np.array([height, width, im_scale], dtype=np.float32) | ||
self._net.setInput(blob, 'data') | ||
self._net.setInput(imInfo, 'im_info') | ||
|
||
# infer | ||
outs = self._net.forward(self._getOutputsNames(self._net)) | ||
t, _ = self._net.getPerfProfile() | ||
logger.debug('Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())) | ||
|
||
# postprocess | ||
classIds = [] | ||
confidences = [] | ||
boxes = [] | ||
for out in outs: | ||
for detection in out[0, 0]: | ||
confidence = detection[2] | ||
if confidence > self._conf_threshold: | ||
left = int(detection[3]) | ||
top = int(detection[4]) | ||
right = int(detection[5]) | ||
bottom = int(detection[6]) | ||
width = right - left + 1 | ||
height = bottom - top + 1 | ||
classIds.append(int(detection[1]) - 1) # Skip background label | ||
confidences.append(float(confidence)) | ||
boxes.append([left, top, width, height]) | ||
|
||
indices = cv2.dnn.NMSBoxes(boxes, confidences, self._conf_threshold, self._nms_threshold) | ||
results = {} | ||
for i in indices: | ||
i = i[0] | ||
box = boxes[i] | ||
left = box[0] | ||
top = box[1] | ||
width = box[2] | ||
height = box[3] | ||
classId = int(classIds[i]) | ||
confidence = confidences[i] | ||
if self._labels[classId] not in results: | ||
results[self._labels[classId]] = [] | ||
results[self._labels[classId]].append([left, top, left+width, top+height, confidence, classId]) | ||
|
||
logger.debug('results: {}'.format(results)) | ||
return results |
Oops, something went wrong.