From 38190257938dd5c53601dd073786536589782be9 Mon Sep 17 00:00:00 2001 From: Vivi Date: Mon, 14 Feb 2022 14:34:19 +0200 Subject: [PATCH 1/6] xml type dataset addition --- .../object_detection_2d/datasets/__init__.py | 5 +- .../datasets/detection_dataset.py | 49 +++++++ .../datasets/xmldataset.py | 133 ++++++++++++++++++ 3 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 src/opendr/perception/object_detection_2d/datasets/xmldataset.py diff --git a/src/opendr/perception/object_detection_2d/datasets/__init__.py b/src/opendr/perception/object_detection_2d/datasets/__init__.py index 2a8a03297b..d5c53a5c33 100644 --- a/src/opendr/perception/object_detection_2d/datasets/__init__.py +++ b/src/opendr/perception/object_detection_2d/datasets/__init__.py @@ -1,5 +1,8 @@ from .detection_dataset import DetectionDataset from .wider_face import WiderFaceDataset from .wider_person import WiderPersonDataset +from .xmldataset import XMLBasedDataset +from .detection_dataset import ConcatDataset -__all__ = ['DetectionDataset', 'WiderFaceDataset', 'WiderPersonDataset'] +__all__ = ['DetectionDataset', 'WiderFaceDataset', 'WiderPersonDataset', 'XMLBasedDataset', + 'ConcatDataset'] diff --git a/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py b/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py index 0889f67a70..8c62edbf44 100644 --- a/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py +++ b/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import bisect +from itertools import accumulate + from opendr.engine.datasets import DatasetIterator @@ -68,3 +71,49 @@ def __getitem__(self, idx): if isinstance(item, tuple): return self.map_function(*item) return self.map_function(item) + + +class ConcatDataset(DetectionDataset): + def __init__(self, datasets): + super(ConcatDataset, self).__init__(classes=datasets[0].classes, dataset_type='concat_dataset', + root=None) + self.cumulative_lengths = list(accumulate([len(dataset) for dataset in datasets])) + self.datasets = datasets + + def set_transform(self, transform): + self._transform = transform + for dataset in self.datasets: + dataset.transform(transform) + + def transform(self, transform): + mapped_datasets = [MappedDetectionDataset(dataset, transform) for dataset in self.datasets] + return ConcatDataset(mapped_datasets) + + def set_image_transform(self, transform): + self._image_transform = transform + for dataset in self.datasets: + dataset.set_image_transform(transform) + + def set_target_transform(self, transform): + self._target_transform = transform + for dataset in self.datasets: + dataset.set_target_transform(transform) + + def __len__(self): + return self.cumulative_lengths[-1] + + def __getitem__(self, item): + dataset_idx = bisect.bisect_right(self.cumulative_lengths, item) + if dataset_idx == 0: + sample_idx = item + else: + sample_idx = item - self.cumulative_lengths[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + +def is_image_type(filename): + return filename.lower().endswith(('png', 'jpg', 'jpeg', 'tiff', 'bmp', 'gif')) + + +def remove_extension(filename): + return '.'.join(filename.split('.')[:-1]) diff --git a/src/opendr/perception/object_detection_2d/datasets/xmldataset.py b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py new file mode 100644 index 0000000000..312734eb2c --- /dev/null +++ b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py @@ -0,0 +1,133 @@ +# Copyright 2020-2021 
OpenDR European Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +try: + import xml.etree.cElementTree as ET +except ImportError: + import xml.etree.ElementTree as ET +import cv2 + +from opendr.engine.data import Image +from opendr.engine.target import BoundingBox, BoundingBoxList +from opendr.perception.object_detection_2d.datasets.detection_dataset import DetectionDataset, is_image_type, remove_extension + + +class XMLBasedDataset(DetectionDataset): + """ + Reader class for datasets annotated with the LabelImg tool in Pascal VOC XML format. + The dataset should be in the following structure: + data_root + |-- images + |-- annotations + The exact names of the folders can be passed as arguments (images_dir) and (annotations_dir). 
+ """ + def __init__(self, dataset_type, root, classes=None, image_transform=None, + target_transform=None, transform=None, splits='', + images_dir='images', annotations_dir='annotations', preload_anno=False): + self.abs_images_dir = os.path.join(root, images_dir) + self.abs_annot_dir = os.path.join(root, annotations_dir) + image_names = [im_filename for im_filename in os.listdir(self.abs_images_dir) + if is_image_type(im_filename)] + + if classes is None: + classes = [] + self.classes = classes + super().__init__(classes, dataset_type, root, image_transform=image_transform, target_transform=target_transform, + transform=transform, image_paths=image_names, splits=splits) + self.bboxes = [] + self.preload_anno = preload_anno + if preload_anno: + for image_name in image_names: + annot_file = os.path.join(self.abs_annot_dir, remove_extension(image_name) + '.xml') + bboxes = self._read_annotation_file(annot_file) + self.bboxes.append(bboxes) + + def _read_annotation_file(self, filename): + root = ET.parse(filename).getroot() + bounding_boxes = [] + for obj in root.iter('object'): + cls_name = obj.find('name').text.strip().lower() + if cls_name not in self.classes: + self.classes.append(cls_name) + cls_id = self.classes.index(cls_name) + xml_box = obj.find('bndbox') + xmin = (float(xml_box.find('xmin').text) - 1) + ymin = (float(xml_box.find('ymin').text) - 1) + xmax = (float(xml_box.find('xmax').text) - 1) + ymax = (float(xml_box.find('ymax').text) - 1) + bounding_box = BoundingBox(name=int(cls_id), + left=float(xmin), top=float(ymin), + width=float(xmax) - float(xmin), + height=float(ymax) - float(ymin)) + bounding_boxes.append(bounding_box) + return BoundingBoxList(boxes=bounding_boxes) + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, item): + image_name = self.image_paths[item] + image_path = os.path.join(self.abs_images_dir, image_name) + img_np = cv2.imread(image_path) + img = Image(img_np) + + if self.preload_anno: + label = 
self.bboxes[item] + else: + annot_file = os.path.join(self.abs_annot_dir, remove_extension(image_name) + '.xml') + label = self._read_annotation_file(annot_file) + + if self._image_transform is not None: + img = self._image_transform(img) + + if self._target_transform is not None: + label = self._target_transform(label) + + if self._transform is not None: + return self._transform(img, label) + return img, label + + def get_image(self, item): + image_name = self.image_paths[item] + image_path = os.path.join(self.abs_images_dir, image_name) + img_np = cv2.imread(image_path) + if self._image_transform is not None: + img = self._image_transform(img_np) + return img + + def get_bboxes(self, item): + boxes = self.bboxes[item] + if self._target_transform is not None: + boxes = self._target_transform(boxes) + return boxes + + +if __name__ == '__main__': + # TODO: remove these after testing + from opendr.perception.object_detection_2d.utils.vis_utils import draw_bounding_boxes + + dataset = XMLBasedDataset(root='/home/administrator/data/agi_human_data', dataset_type='agi_human', + images_dir='no_human', annotations_dir='no_human_anot') + print(len(dataset)) + + all_boxes = [[[] for _ in range(len(dataset))] + for _ in range(dataset.num_classes)] + + for i, (img, targets) in enumerate(dataset): + img = draw_bounding_boxes(img.opencv(), targets, class_names=dataset.classes) + img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5) + cv2.imshow('img', img) + cv2.waitKey(0) + cv2.destroyAllWindows() From cb7118577d4e52c4d60c0de37469a40ac64b0ac0 Mon Sep 17 00:00:00 2001 From: Vivi Date: Tue, 15 Feb 2022 15:35:50 +0200 Subject: [PATCH 2/6] documentation for concatdataset --- .../datasets/detection_dataset.py | 5 +++++ .../datasets/xmldataset.py | 19 ------------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py b/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py index 
8c62edbf44..c72f9c71c9 100644 --- a/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py +++ b/src/opendr/perception/object_detection_2d/datasets/detection_dataset.py @@ -74,6 +74,11 @@ def __getitem__(self, idx): class ConcatDataset(DetectionDataset): + """ + Basic dataset concatenation class. The datasets are assumed to have the same classes. + + :param datasets: list of DetectionDataset type or subclass + """ def __init__(self, datasets): super(ConcatDataset, self).__init__(classes=datasets[0].classes, dataset_type='concat_dataset', root=None) diff --git a/src/opendr/perception/object_detection_2d/datasets/xmldataset.py b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py index 312734eb2c..a8a16d66b1 100644 --- a/src/opendr/perception/object_detection_2d/datasets/xmldataset.py +++ b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py @@ -112,22 +112,3 @@ def get_bboxes(self, item): if self._target_transform is not None: boxes = self._target_transform(boxes) return boxes - - -if __name__ == '__main__': - # TODO: remove these after testing - from opendr.perception.object_detection_2d.utils.vis_utils import draw_bounding_boxes - - dataset = XMLBasedDataset(root='/home/administrator/data/agi_human_data', dataset_type='agi_human', - images_dir='no_human', annotations_dir='no_human_anot') - print(len(dataset)) - - all_boxes = [[[] for _ in range(len(dataset))] - for _ in range(dataset.num_classes)] - - for i, (img, targets) in enumerate(dataset): - img = draw_bounding_boxes(img.opencv(), targets, class_names=dataset.classes) - img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5) - cv2.imshow('img', img) - cv2.waitKey(0) - cv2.destroyAllWindows() From 4418aa246786127dce3dadae4582b1f24702f2ae Mon Sep 17 00:00:00 2001 From: Vivi Date: Tue, 15 Feb 2022 18:13:25 +0200 Subject: [PATCH 3/6] mxnet-cu102 changes --- src/opendr/perception/object_detection_2d/dependencies.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/opendr/perception/object_detection_2d/dependencies.ini b/src/opendr/perception/object_detection_2d/dependencies.ini index 0522648390..6c4ab1dfa1 100644 --- a/src/opendr/perception/object_detection_2d/dependencies.ini +++ b/src/opendr/perception/object_detection_2d/dependencies.ini @@ -1,7 +1,7 @@ [runtime] # 'python' key expects a value using the Python requirements file format # https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format -python=mxnet==1.8.0 +python=mxnet-cu102==1.8.0 gluoncv==0.11.0b20210908 tqdm pycocotools>=2.0.4 From ab79c3d579900059959db1c7e22d19b57aad3972 Mon Sep 17 00:00:00 2001 From: Vivi Date: Tue, 15 Feb 2022 18:17:31 +0200 Subject: [PATCH 4/6] fixed license --- .../perception/object_detection_2d/datasets/xmldataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opendr/perception/object_detection_2d/datasets/xmldataset.py b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py index a8a16d66b1..746dba848b 100644 --- a/src/opendr/perception/object_detection_2d/datasets/xmldataset.py +++ b/src/opendr/perception/object_detection_2d/datasets/xmldataset.py @@ -1,4 +1,4 @@ -# Copyright 2020-2021 OpenDR European Project +# Copyright 2020-2022 OpenDR European Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From 537070e6eb8b54ca601856d4993b042a3fea6686 Mon Sep 17 00:00:00 2001 From: Vivi Date: Fri, 18 Feb 2022 18:11:20 +0200 Subject: [PATCH 5/6] changed order of mxnet cpu-gpu installation --- bin/install.sh | 11 +++++------ .../perception/object_detection_2d/dependencies.ini | 2 +- .../perception/object_detection_2d/ssd/ssd_learner.py | 1 + 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bin/install.sh b/bin/install.sh index f8078bd0da..49da5a0be1 100755 --- a/bin/install.sh +++ b/bin/install.sh @@ -35,17 +35,16 @@ pip3 install setuptools configparser sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' \ && curl -s https://raw.githubusercontent.com/ros/rosdistro/master/ros.asc | sudo apt-key add - +# Build OpenDR +make install_compilation_dependencies +make install_runtime_dependencies +make libopendr -# If working on GPU install GPU dependencies beforehand +# If working on GPU install GPU dependencies as needed if [[ "${OPENDR_DEVICE}" == "gpu" ]]; then echo "[INFO] Installing mxnet-cu102==1.8.0. You can override this later if you are using a different CUDA version." 
pip3 install mxnet-cu102==1.8.0 fi -# Build OpenDR -make install_compilation_dependencies -make install_runtime_dependencies -make libopendr - deactivate diff --git a/src/opendr/perception/object_detection_2d/dependencies.ini b/src/opendr/perception/object_detection_2d/dependencies.ini index 6c4ab1dfa1..0522648390 100644 --- a/src/opendr/perception/object_detection_2d/dependencies.ini +++ b/src/opendr/perception/object_detection_2d/dependencies.ini @@ -1,7 +1,7 @@ [runtime] # 'python' key expects a value using the Python requirements file format # https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format -python=mxnet-cu102==1.8.0 +python=mxnet==1.8.0 gluoncv==0.11.0b20210908 tqdm pycocotools>=2.0.4 diff --git a/src/opendr/perception/object_detection_2d/ssd/ssd_learner.py b/src/opendr/perception/object_detection_2d/ssd/ssd_learner.py index 8bb8b8f65b..34dbdc5760 100644 --- a/src/opendr/perception/object_detection_2d/ssd/ssd_learner.py +++ b/src/opendr/perception/object_detection_2d/ssd/ssd_learner.py @@ -87,6 +87,7 @@ def __init__(self, lr=1e-3, epochs=120, batch_size=8, self.ctx = mx.gpu(0) else: self.ctx = mx.cpu() + print("Device set to cuda but no GPU available, using CPU...") else: self.ctx = mx.cpu() From 7e65f4029858e07d478e2ed4c20a8d349bc0f864 Mon Sep 17 00:00:00 2001 From: Vivi Date: Fri, 18 Feb 2022 20:15:34 +0200 Subject: [PATCH 6/6] documentation addition + fix --- .../object-detection-2d-centernet.md | 2 +- .../reference/object-detection-2d-datasets.md | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 docs/reference/object-detection-2d-datasets.md diff --git a/docs/reference/object-detection-2d-centernet.md b/docs/reference/object-detection-2d-centernet.md index d56f1e9406..72274aa839 100644 --- a/docs/reference/object-detection-2d-centernet.md +++ b/docs/reference/object-detection-2d-centernet.md @@ -5,7 +5,7 @@ The *centernet* module contains the *CenterNetDetectorLearner* class, which 
inhe ### Class CenterNetDetectorLearner Bases: `engine.learners.Learner` -The *CenterNetDetectorLearner* class is a wrapper of the SSD detector[[1]](#centernet-1) +The *CenterNetDetectorLearner* class is a wrapper of the CenterNet detector[[1]](#centernet-1) [GluonCV implementation](https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/center_net/center_net.py). It can be used to perform object detection on images (inference) as well as train new object detection models. diff --git a/docs/reference/object-detection-2d-datasets.md b/docs/reference/object-detection-2d-datasets.md new file mode 100644 index 0000000000..de0990e8c1 --- /dev/null +++ b/docs/reference/object-detection-2d-datasets.md @@ -0,0 +1,87 @@ +# Object Detection 2D Datasets + +## Base Classes + +### DetectionDataset base class + +Bases: `engine.datasets.DatasetIterator` + +The *DetectionDataset* class inherits from the *DatasetIterator* class and extends it with functions and properties aimed at 2d Object Detection datasets. Each *DetectionDataset* object must be initialized with the following parameters: + +- **classes**: *list*\ + List of class names of the training dataset. +- **dataset_type**: *str*\ + Dataset type, i.e., an assigned name. +- **root**: *str*\ + Path to dataset root directory. +- **image_paths**: *list, default=None*\ + List of image data paths. +- **splits**: *str, default=''*\ + List of dataset splits to load (e.g., train/val). +- **image_transform**: *callable object, default=None*\ + Transformation to apply to images. Intended for image format transformations. +- **target_transform**: *callable object, default=None*\ + Transformation to apply to bounding boxes. Intended for formatting the bounding boxes for each detector. +- **transform**: *callable object, default=None*\ + Transformation to apply to both images and bounding boxes. Intended for data augmentation purposes. 
+ +Methods: + +#### `DetectionDataset.set_transform` +Setter for the internal **transform** object/function. + +#### `DetectionDataset.set_image_transform` +Setter for the internal **image_transform** object/function. + +#### `DetectionDataset.set_target_transform` +Setter for the internal **target_transform** object/function. + +#### `DetectionDataset.transform` +Returns the `DetectionDataset` wrapped as a `MappedDetectionDataset`, where the data is transformed according to the argument callable object/function. This function ensures fit/eval compatibility between `DetectionDataset` and `ExternalDataset` for [GluonCV](https://github.com/dmlc/gluon-cv) based detectors. + +#### `DetectionDataset.get_image` +Returns an image from the dataset. Intended for test sets without annotations. + +#### `DetectionDataset.get_bboxes` +Returns the bounding boxes for a given sample. + + +### MappedDetectionDataset class + +Bases: `engine.datasets.DatasetIterator` + +This class wraps any `DetectionDataset` and applies `map_function` to the data. + +### ConcatDataset class + +Bases: `perception.object_detection_2d.datasets.DetectionDataset` + +Returns a new `DetectionDataset` which is a concatenation of the `datasets` param. The datasets are assumed to have the same classes. + +### XMLBasedDataset class + +Bases: `perception.object_detection_2d.datasets.DetectionDataset` + +This class is intended for any dataset in PASCAL VOC .xml format, making it compatible with datasets annotated using the [labelImg](https://github.com/tzutalin/labelImg) tool. Each *XMLBasedDataset* object must be initialized with the following parameters: + +- **dataset_type**: *str*\ + Dataset type, i.e., assigned name. +- **root**: *str*\ + Path to dataset root directory. +- **classes**: *list, default=None*\ + Class names. If None, they will be inferred from the annotations. +- **splits**: *str, default=''*\ + List of dataset splits to load (e.g., train/val). 
+- **image_transform**: *callable object, default=None*\ + Transformation to apply to images. Intended for image format transformations. +- **target_transform**: *callable object, default=None*\ + Transformation to apply to bounding boxes. Intended for formatting the bounding boxes for each detector. +- **transform**: *callable object, default=None*\ + Transformation to apply to both images and bounding boxes. Intended for data augmentation purposes. +- **images_dir**: *str, default='images'*\ + Name of subdirectory containing dataset images. +- **annotations_dir**: *str, default='annotations'*\ + Name of subdirectory containing dataset annotations. +- **preload_anno**: *bool, default=False*\ + Whether to preload annotations, for datasets that fit in memory. +