Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
342 lines (262 sloc) 12.8 KB
# Log detections
GST_DEBUG=python:4 gst-launch-1.0 filesrc location=video.mp4 ! decodebin ! videoconvert ! video/x-raw,format=RGB ! gst_tf_detection config=cfg.yml ! videoconvert ! gtksink sync=False
# Draw detections
gst-launch-1.0 filesrc location=video.mp4 ! decodebin ! videoconvert ! video/x-raw,format=RGB ! gst_tf_detection config=cfg.yml ! videoconvert ! gst_detection_overlay ! videoconvert ! gtksink sync=False
import os
import logging
import timeit
import traceback
import time
import cv2
import tensorflow as tf
from typing import List, Tuple
import yaml
import numpy as np
from pygst_utils import gst_buffer_with_pad_to_ndarray, Gst, GObject
from pygst_utils.gst_objects_info_meta import gst_meta_write
def is_gpu(device: str) -> bool:
return "gpu" in device.lower()
def create_config(device: str = '/device:CPU:0',
per_process_gpu_memory_fraction: float = 0.0,
log_device_placement: bool = False) -> tf.ConfigProto:
if is_gpu(device):
config = tf.ConfigProto(log_device_placement=log_device_placement)
if per_process_gpu_memory_fraction > 0.0:
config.gpu_options.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction
config.gpu_options.allow_growth = True
config = tf.ConfigProto(
log_device_placement=log_device_placement, device_count={'GPU': 0})
return config
def parse_graph_def(model_path: str) -> tf.GraphDef:
model_path = os.path.abspath(model_path)
assert os.path.isfile(model_path), "Invalid filename {}".format(model_path)
with tf.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
return graph_def
def import_graph(graph_def: tf.GraphDef, device: str, name: str="") -> tf.Graph:
with tf.device(device):
graph = tf.Graph()
with graph.as_default():
tf.import_graph_def(graph_def, name=name)
return graph
def load_labels_from_file(filename: str) -> dict:
assert os.path.isfile(filename), "Invalid filename {}".format(filename)
labels = {}
with open(filename, 'r') as f:
for line in f:
items = line.strip().split(":")
# print(items)
label_id, label_name = items[:2]
labels[int(label_id)] = label_name[1:]
except Exception as e:
print(e, items)
return labels
def load_config(filename: str) -> dict:
filename = os.path.abspath(filename)
assert os.path.isfile(filename), "Invalid filename {}".format(filename)
with open(filename, 'r') as stream:
data = yaml.load(stream, Loader=yaml.Loader)
return data
except yaml.YAMLError as exc:
raise OSError('Parsing error. Filename: {}'.format(filename))
class TfObjectDetectionModel(object):
def __init__(self, weights: str,
threshold: float = 0.5,
device: str = '/device:CPU:0',
per_process_gpu_memory_fraction: float = 0.0,
labels: List[str] = None,
input_shape: Tuple[int, int]=(300, 300)):
# TODO Docs
graph_def = parse_graph_def(weights)
config = create_config(device,
graph = import_graph(graph_def, device)
print(f"Model {weights} placed on {device}")
self.session = tf.Session(graph=graph, config=config)
# Taken from official website
self.input = graph.get_tensor_by_name("image_tensor:0")
self.input_shape = input_shape or (300, 300)
self.input_shape = tuple(self.input_shape)
# print([ for n in graph.as_graph_def().node][:10])
# print("Shape : ", self.input.shape)
# Taken from official website
output_names = ["detection_classes:0",
self.output = [graph.get_tensor_by_name(name) for name in output_names]
self.threshold = threshold
self.labels = labels or {}
self._box_scaler = None
def process_single(self, image: np.ndarray) -> List[dict]:
return self._process(np.expand_dims(self._preprocess(image), 0), image.shape[:2][::-1])[0]
def process_batch(self, images: List[np.ndarray]) -> List[dict]:
images_ = np.stack([self._preprocess(image) for image in images])
return self._process(images_, images[0].shape[:2][::-1])
def _process(self, images: np.ndarray, initial_shape: Tuple[int, int]) -> List[dict]:
classes, boxes, scores =,
feed_dict={self.input: images})
# _, h, w = images.shape[:3]
w, h = initial_shape
box_scaler = np.array([h, w, h, w])
num_detections = len(classes)
objects = [[] for _ in range(num_detections)]
for i in range(num_detections):
for class_id, box, score in zip(classes[i], boxes[i], scores[i]):
if class_id not in self.labels or \
score < self.threshold:
ymin, xmin, ymax, xmax = list(map(lambda x: int(x), box * box_scaler))
object_info = {'confidence': float(score),
'bounding_box': [xmin, ymin, xmax - xmin, ymax - ymin],
'class_name': self.labels[class_id]}
return objects
def _preprocess(self, image: np.ndarray) -> np.ndarray:
return cv2.resize(image, self.input_shape, interpolation=cv2.INTER_NEAREST)
def __del__(self):
""" Releases model when object deleted """
if self.session is not None:
def tf_object_detection_model_from_file(filename: str) -> TfObjectDetectionModel:
:param filename: filename to model config
return tf_object_detection_model_from_config(load_config(filename))
def tf_object_detection_model_from_config(config: dict) -> TfObjectDetectionModel:
:param config: model config
labels = load_labels_from_file(config['labels'])
return TfObjectDetectionModel(weights=config['weights'],
threshold=config.get('threshold', 0.5),
device=config.get('device', "/device:CPU:0"),
per_process_gpu_memory_fraction=config.get('per_process_gpu_memory_fraction', 0.0),
log_device_placement=config.get("log_device_placemenent", False),
input_shape=config.get('input_shape', (300, 300)))
class GstTfDetectionPluginPy(Gst.Element):
# Metadata Explanation:
GST_PLUGIN_NAME = 'gst_tf_detection'
__gstmetadata__ = ("Name",
_srctemplate ='src', Gst.PadDirection.SRC,
_sinktemplate ='sink', Gst.PadDirection.SINK,
__gsttemplates__ = (_srctemplate, _sinktemplate)
# Explanation:
# Example:
__gproperties__ = {
"model": (GObject.TYPE_PYOBJECT,
"Contains model TfObjectDetectionModel",
"config": (str,
"Path to config file",
"Contains path to config *.yml supported by TfObjectDetectionModel",
None, # default
def __init__(self):
super(GstTfDetectionPluginPy, self).__init__()
# Explained:
# Explanation how to init Pads
self.sinkpad = Gst.Pad.new_from_template(self._sinktemplate, 'sink')
# Set chain function
self.sinkpad.set_chain_function_full(self.chainfunc, None)
# Set event function
self.sinkpad.set_event_function_full(self.eventfunc, None)
self.srcpad = Gst.Pad.new_from_template(self._srctemplate, 'src')
# Set event function
self.srcpad.set_event_function_full(self.srceventfunc, None)
# Set query function
self.srcpad.set_query_function_full(self.srcqueryfunc, None)
self.model = None
self.config = None
self._channels = 3 # RGB -> 3 channels
def chainfunc(self, pad: Gst.Pad, parent, buffer: Gst.Buffer) -> Gst.FlowReturn:
:param parent: GstDetectionOverlay
if self.model is None:
return self.srcpad.push(buffer)
# Convert Gst.Buffer to np.ndarray
image = gst_buffer_with_pad_to_ndarray(buffer, pad, self._channels)
# model inference
objects = self.model.process_single(image)
# write objects to as Gst.Buffer's metadata
# Explained:
gst_meta_write(buffer, objects)
except Exception as e:
return Gst.FlowReturn.ERROR
return self.srcpad.push(buffer)
def do_get_property(self, prop: GObject.GParamSpec):
if == 'model':
return self.model
if == 'config':
return self.config
raise AttributeError('unknown property %s' %
def do_set_property(self, prop: GObject.GParamSpec, value):
if == 'model':
self.model = value
elif == "config":
self.model = tf_object_detection_model_from_file(value)
self.config = value
raise AttributeError('unknown property %s' %
def eventfunc(self, pad: Gst.Pad, parent, event: Gst.Event) -> bool:
""" Forwards event to SRC (DOWNSTREAM)
:param parent: GstTfDetectionPluginPy
return self.srcpad.push_event(event)
def srcqueryfunc(self, pad: Gst.Pad, parent, query: Gst.Query) -> bool:
""" Forwards query bacj to SINK (UPSTREAM)
:param parent: GstTfDetectionPluginPy
return self.sinkpad.query(query)
def srceventfunc(self, pad: Gst.Pad, parent, event: Gst.Event) -> bool:
""" Forwards event back to SINK (UPSTREAM)
:param parent: GstTfDetectionPluginPy
return self.sinkpad.push_event(event)
# Required for registering plugin dynamically
# Explained:
__gstelementfactory__ = (GstTfDetectionPluginPy.GST_PLUGIN_NAME,
Gst.Rank.NONE, GstTfDetectionPluginPy)
You can’t perform that action at this time.