Object detection 2d yolov5 (#360)
* DetectionDataset evaluation fix for empty detections

* yolov5 initial commit

* yolov5 learner cleanup

* yolov5 documentation fix

* yolov5 demo fix

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: Nikolaos Passalis <passalis@users.noreply.github.com>

* added force_reload as a parameter for the user to avoid redownloading the model every time

* added image size as inference parameter

* ROS1 docs update

* pep8 fixes

* pep8 fixes

* added tool to .yml files for testing

* fix dependencies

* Minor improvements on inference demo

* Simple YOLOv5 webcam demo

* Minor fix for deprecation warning

* Added 'opendr' in node name

* Added webcam demo reference in yolov5 readme list

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>

* Update docs/reference/object-detection-2d-yolov5.md

Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>

* Update projects/python/perception/object_detection_2d/yolov5/inference_tutorial.ipynb

Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>

* Update projects/python/perception/object_detection_2d/yolov5/inference_tutorial.ipynb

Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>

* index + changelog + notebook fixes

* Changelog fix

Co-authored-by: Nikolaos Passalis <passalis@users.noreply.github.com>
Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>
Co-authored-by: ad-daniel <44834743+ad-daniel@users.noreply.github.com>
Co-authored-by: ad-daniel <daniel.dias@epfl.ch>
5 people committed Nov 30, 2022
1 parent cf99c7b commit fad5d29
Showing 21 changed files with 860 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test_packages.yml
@@ -41,6 +41,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
@@ -90,6 +91,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
4 changes: 4 additions & 0 deletions .github/workflows/tests_suite.yml
@@ -74,6 +74,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- simulation/human_model_generation
@@ -182,6 +183,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
@@ -255,6 +257,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
@@ -334,6 +337,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
4 changes: 4 additions & 0 deletions .github/workflows/tests_suite_develop.yml
@@ -75,6 +75,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- simulation/human_model_generation
@@ -186,6 +187,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
@@ -260,6 +262,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
@@ -339,6 +342,7 @@ jobs:
- perception/object_detection_2d/ssd
- perception/object_detection_2d/nanodet
- perception/object_detection_2d/yolov3
- perception/object_detection_2d/yolov5
- perception/object_detection_2d/retinaface
- perception/object_detection_2d/nms
- perception/facial_expression_recognition
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# OpenDR Toolkit Change Log

## Version 2.0.0
Released on December XX, 2022.

- New Features:
- Added YOLOv5 as an inference-only tool ([#360](https://github.com/opendr-eu/opendr/pull/360)).

## Version 1.1.1
Released on June 30th, 2022.

1 change: 1 addition & 0 deletions docs/reference/index.md
@@ -46,6 +46,7 @@ Neither the copyright holder nor any applicable licensor will be liable for any
- [centernet Module](object-detection-2d-centernet.md)
- [ssd Module](object-detection-2d-ssd.md)
- [yolov3 Module](object-detection-2d-yolov3.md)
- [yolov5 Module](object-detection-2d-yolov5.md)
- [seq2seq-nms Module](object-detection-2d-nms-seq2seq_nms.md)
- object detection 3d:
- [voxel Module](voxel-object-detection-3d.md)
81 changes: 81 additions & 0 deletions docs/reference/object-detection-2d-yolov5.md
@@ -0,0 +1,81 @@
## YOLOv5DetectorLearner module

The *yolov5* module contains the *YOLOv5DetectorLearner* class, which inherits from the abstract class *Learner*.

### Class YOLOv5DetectorLearner
Bases: `engine.learners.Learner`

The *YOLOv5DetectorLearner* class is a wrapper around the [Ultralytics implementation](https://github.com/ultralytics/yolov5) of the YOLOv5 detector [[1]](#yolo-1), based on its availability through [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).
It can be used to perform object detection on images (inference only).

The [YOLOv5DetectorLearner](/src/opendr/perception/object_detection_2d/yolov5/yolov5_learner.py) class has the following
public methods:

#### `YOLOv5DetectorLearner` constructor
```python
YOLOv5DetectorLearner(self, model_name, path, device, temp_path, force_reload)
```

Constructor parameters:

- **model_name**: *str*\
Specifies the name of the model to be used. Available models:
- 'yolov5n' (46.0% mAP, 1.9M parameters)
- 'yolov5s' (56.0% mAP, 7.2M parameters)
- 'yolov5m' (63.9% mAP, 21.2M parameters)
- 'yolov5l' (67.2% mAP, 46.5M parameters)
- 'yolov5x' (68.9% mAP, 86.7M parameters)
- 'yolov5n6' (50.7% mAP, 3.2M parameters)
- 'yolov5s6' (63.0% mAP, 16.8M parameters)
- 'yolov5m6' (69.0% mAP, 35.7M parameters)
- 'yolov5l6' (71.6% mAP, 76.8M parameters)
- 'custom' (for custom models, the `path` parameter must be set to point to the location of the weights file).
Note that the reported mAP is mAP@0.5, measured on the [COCO val2017 dataset](https://github.com/ultralytics/yolov5/releases).
- **path**: *str, default=None*\
For custom-trained models, specifies the path to the weights to be loaded.
- **device**: *{'cuda', 'cpu'}, default='cuda'*\
Specifies the device used for inference.
- **temp_path**: *str, default='.'*\
Specifies the path to where the weights will be downloaded when using pretrained models.
- **force_reload**: *bool, default=False*\
Sets the `force_reload` parameter of the PyTorch Hub `load` method.
When set to `True`, the model is re-downloaded instead of loaded from the cache, which resolves caching issues.
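
As a quick illustration, a minimal construction sketch follows; the weights file `./my_model.pt` is a hypothetical placeholder for user-trained weights:

```python
from opendr.perception.object_detection_2d import YOLOv5DetectorLearner

# Pretrained model, fetched through PyTorch Hub on first use;
# force_reload=True re-downloads it instead of using a possibly stale cache
detector = YOLOv5DetectorLearner(model_name='yolov5s', device='cpu', force_reload=True)

# Custom-trained model; './my_model.pt' is a hypothetical path to user weights
custom_detector = YOLOv5DetectorLearner(model_name='custom', path='./my_model.pt', device='cpu')
```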


#### `YOLOv5DetectorLearner.infer`
The `infer` method:
```python
YOLOv5DetectorLearner.infer(self, img, size)
```

Performs inference on a single image.

Parameters:

- **img**: *object*\
Object of type `engine.data.Image` or an OpenCV image (a BGR `numpy` array).
- **size**: *int, default=640*\
Size of the image used for inference.
The image is resized to this size on both sides before being fed to the model.

#### Examples

* Inference and result drawing example on a test .jpg image using OpenCV:
```python
import torch
from opendr.engine.data import Image
from opendr.perception.object_detection_2d import YOLOv5DetectorLearner
from opendr.perception.object_detection_2d import draw_bounding_boxes

yolo = YOLOv5DetectorLearner(model_name='yolov5s', device='cpu')

torch.hub.download_url_to_file('https://ultralytics.com/images/zidane.jpg', 'zidane.jpg') # download image
im1 = Image.open('zidane.jpg') # OpenDR image

results = yolo.infer(im1)
draw_bounding_boxes(im1.opencv(), results, yolo.classes, show=True, line_thickness=3)
```
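
The inference size can be set explicitly through the `size` parameter; a small sketch reusing the `yolo` learner and `im1` image from the example above:

```python
# Run inference at a higher resolution; the image is resized to 1280 on both sides
results_hd = yolo.infer(im1, size=1280)
```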

#### References
<a name="yolo-1" href="https://ultralytics.com/yolov5">[1]</a> YOLOv5: The friendliest AI architecture you'll ever use.
1 change: 1 addition & 0 deletions projects/opendr_ws/src/perception/CMakeLists.txt
Expand Up @@ -31,6 +31,7 @@ catkin_install_python(PROGRAMS
scripts/pose_estimation.py
scripts/fall_detection.py
scripts/object_detection_2d_nanodet.py
scripts/object_detection_2d_yolov5.py
scripts/object_detection_2d_detr.py
scripts/object_detection_2d_gem.py
scripts/semantic_segmentation_bisenet.py
Expand Down
8 changes: 6 additions & 2 deletions projects/opendr_ws/src/perception/README.md
@@ -101,7 +101,7 @@ Reference images should be placed in a defined structure like:
under `/opendr/face_recognition_id`.

## 2D Object Detection ROS Nodes
ROS nodes are implemented for the SSD, YOLOv3, CenterNet, DETR, Nanodet and YOLOv5 generic object detectors.
Assuming that you have already [activated the OpenDR environment](../../../../docs/reference/installation.md), [built your workspace](../../README.md) and started roscore (i.e., just run `roscore`).

1. Start the node responsible for publishing images. If you have a USB camera, then you can use the corresponding node (assuming you have installed the corresponding package):
@@ -116,7 +116,7 @@ rosrun perception object_detection_2d_ssd.py
```
The annotated image stream can be viewed using `rqt_image_view`, and the default topic name is
`/opendr/image_boxes_annotated`. The bounding boxes alone are also published as `/opendr/objects`.
Similarly, the YOLOv3, CenterNet, DETR, Nanodet and YOLOv5 detector nodes can be run with:
```shell
rosrun perception object_detection_2d_yolov3.py
```
@@ -132,6 +132,10 @@ or
```shell
rosrun perception object_detection_2d_nanodet.py
```
or
```shell
rosrun perception object_detection_2d_yolov5.py
```
respectively.
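
The YOLOv5 node also exposes optional arguments for the topics, device and architecture; for example, combining two of the `argparse` options defined in the node:
```shell
rosrun perception object_detection_2d_yolov5.py --model_name yolov5m --device cpu
```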

## Face Detection ROS Node
139 changes: 139 additions & 0 deletions projects/opendr_ws/src/perception/scripts/object_detection_2d_yolov5.py
@@ -0,0 +1,139 @@
#!/usr/bin/env python
# Copyright 2020-2022 OpenDR European Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import torch

import rospy
from vision_msgs.msg import Detection2DArray
from sensor_msgs.msg import Image as ROS_Image
from opendr_bridge import ROSBridge

from opendr.engine.data import Image
from opendr.perception.object_detection_2d import YOLOv5DetectorLearner
from opendr.perception.object_detection_2d import draw_bounding_boxes


class ObjectDetectionYOLONode:

    def __init__(self, input_rgb_image_topic="/usb_cam/image_raw",
                 output_rgb_image_topic="/opendr/image_objects_annotated", detections_topic="/opendr/objects",
                 device="cuda", model_name="yolov5s"):
        """
        Creates a ROS Node for object detection with YOLOv5.
        :param input_rgb_image_topic: Topic from which we are reading the input image
        :type input_rgb_image_topic: str
        :param output_rgb_image_topic: Topic to which we are publishing the annotated image (if None, no annotated
        image is published)
        :type output_rgb_image_topic: str
        :param detections_topic: Topic to which we are publishing the annotations (if None, no object detection message
        is published)
        :type detections_topic: str
        :param device: device on which we are running inference ('cpu' or 'cuda')
        :type device: str
        :param model_name: network architecture name
        :type model_name: str
        """
        self.input_rgb_image_topic = input_rgb_image_topic

        if output_rgb_image_topic is not None:
            self.image_publisher = rospy.Publisher(output_rgb_image_topic, ROS_Image, queue_size=1)
        else:
            self.image_publisher = None

        if detections_topic is not None:
            self.object_publisher = rospy.Publisher(detections_topic, Detection2DArray, queue_size=1)
        else:
            self.object_publisher = None

        self.bridge = ROSBridge()

        # Initialize the object detector
        self.object_detector = YOLOv5DetectorLearner(model_name=model_name, device=device)

    def listen(self):
        """
        Start the node and begin processing input data.
        """
        rospy.init_node('opendr_object_detection_yolov5_node', anonymous=True)
        rospy.Subscriber(self.input_rgb_image_topic, ROS_Image, self.callback, queue_size=1, buff_size=10000000)
        rospy.loginfo("Object detection YOLOv5 node started.")
        rospy.spin()

    def callback(self, data):
        """
        Callback that processes the input data and publishes to the corresponding topics.
        :param data: input message
        :type data: sensor_msgs.msg.Image
        """
        # Convert sensor_msgs.msg.Image into OpenDR Image
        image = self.bridge.from_ros_image(data, encoding='bgr8')

        # Run object detection
        boxes = self.object_detector.infer(image)

        # Publish detections as a ROS message
        ros_boxes = self.bridge.to_ros_bounding_box_list(boxes)  # Convert to ROS bounding_box_list
        if self.object_publisher is not None:
            self.object_publisher.publish(ros_boxes)

        if self.image_publisher is not None:
            # Get an OpenCV image back
            image = image.opencv()
            # Annotate image with object detection boxes
            image = draw_bounding_boxes(image, boxes, class_names=self.object_detector.classes)
            # Convert the annotated OpenDR image to a ROS image message using the bridge and publish it
            self.image_publisher.publish(self.bridge.to_ros_image(Image(image), encoding='bgr8'))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_rgb_image_topic", help="Topic name for input rgb image",
                        type=str, default="/usb_cam/image_raw")
    parser.add_argument("-o", "--output_rgb_image_topic", help="Topic name for output annotated rgb image",
                        type=lambda value: value if value.lower() != "none" else None,
                        default="/opendr/image_objects_annotated")
    parser.add_argument("-d", "--detections_topic", help="Topic name for detection messages",
                        type=lambda value: value if value.lower() != "none" else None,
                        default="/opendr/objects")
    parser.add_argument("--device", help="Device to use, either \"cpu\" or \"cuda\", defaults to \"cuda\"",
                        type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--model_name", help="Network architecture, defaults to \"yolov5s\"",
                        type=str, default="yolov5s", choices=['yolov5s', 'yolov5n', 'yolov5m', 'yolov5l', 'yolov5x',
                                                              'yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'custom'])
    args = parser.parse_args()

    # Select the inference device, falling back to CPU if CUDA is unavailable
    try:
        if args.device == "cuda" and torch.cuda.is_available():
            device = "cuda"
        elif args.device == "cuda":
            print("GPU not found. Using CPU instead.")
            device = "cpu"
        else:
            print("Using CPU.")
            device = "cpu"
    except Exception:
        print("Using CPU.")
        device = "cpu"

    object_detection_yolov5_node = ObjectDetectionYOLONode(device=device, model_name=args.model_name,
                                                           input_rgb_image_topic=args.input_rgb_image_topic,
                                                           output_rgb_image_topic=args.output_rgb_image_topic,
                                                           detections_topic=args.detections_topic)
    object_detection_yolov5_node.listen()


if __name__ == '__main__':
    main()
7 changes: 7 additions & 0 deletions projects/python/perception/object_detection_2d/yolov5/README.md
@@ -0,0 +1,7 @@
# YOLOv5DetectorLearner Demos

This folder contains minimal code usage examples that showcase the basic inference function of the YOLOv5DetectorLearner
provided by OpenDR. Specifically, the following examples are provided:
1. inference_demo.py: Perform inference on a single image. Setting `--device cpu` performs inference on CPU.
2. webcam_demo.py: A simple tool that performs live object detection using a webcam (a minimal sketch of the same idea is shown below).
3. inference_tutorial.ipynb: Perform inference using pretrained or custom models.
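
For illustration, a minimal webcam loop in the spirit of webcam_demo.py could look as follows. This is a sketch, not the bundled demo; the camera index 0 is an assumption, and the frame is passed to `infer` directly since the learner also accepts OpenCV images:

```python
import cv2

from opendr.perception.object_detection_2d import YOLOv5DetectorLearner, draw_bounding_boxes

yolo = YOLOv5DetectorLearner(model_name='yolov5s', device='cpu')
cap = cv2.VideoCapture(0)  # assumed default webcam at index 0

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    boxes = yolo.infer(frame)  # infer() also accepts OpenCV (BGR numpy) images
    annotated = draw_bounding_boxes(frame, boxes, class_names=yolo.classes)
    cv2.imshow('OpenDR YOLOv5 webcam demo', annotated)
    if cv2.waitKey(1) & 0xFF == ord('q'):  # press 'q' to quit
        break

cap.release()
cv2.destroyAllWindows()
```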
