
Commit

Merge pull request #453 from yuyu2172/coco_detection
Add COCOBboxDataset
Hakuyume committed Jun 5, 2018
2 parents 49d4254 + e64d00d commit dcbfcbb
Showing 6 changed files with 357 additions and 0 deletions.
2 changes: 2 additions & 0 deletions chainercv/datasets/__init__.py
@@ -10,6 +10,8 @@
from chainercv.datasets.cityscapes.cityscapes_test_image_dataset import CityscapesTestImageDataset # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_colors # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_names # NOQA
from chainercv.datasets.coco.coco_bbox_dataset import COCOBboxDataset # NOQA
from chainercv.datasets.coco.coco_utils import coco_bbox_label_names # NOQA
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
from chainercv.datasets.cub.cub_point_dataset import CUBPointDataset # NOQA
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
Empty file.
163 changes: 163 additions & 0 deletions chainercv/datasets/coco/coco_bbox_dataset.py
@@ -0,0 +1,163 @@
from collections import defaultdict
import json
import numpy as np
import os

from chainercv import utils

from chainercv.datasets.coco.coco_utils import get_coco

from chainercv.chainer_experimental.datasets.sliceable import GetterDataset


class COCOBboxDataset(GetterDataset):

"""Bounding box dataset for `MS COCO2014`_.
.. _`MS COCO2014`: http://mscoco.org/dataset/#detections-challenge2015
There are a total of 82,783 training and 40,504 validation images.
The 'minival' split is a 5,000-image subset of the validation images,
and the remaining validation images form the 'valminusminival' split.
Concrete lists of image ids and annotations for these splits can be
found `here`_.
Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/coco`.
split ({'train', 'val', 'minival', 'valminusminival'}): Select
a split of the dataset.
use_crowded (bool): If true, use bounding boxes that are labeled as
crowded in the original annotation. The default value is
:obj:`False`.
return_area (bool): If true, this dataset returns the area of the
mask of each object. The default value is :obj:`False`.
return_crowded (bool): If true, this dataset returns a boolean array
that indicates whether bounding boxes are labeled as crowded
or not. The default value is :obj:`False`.
.. _`here`: https://github.com/rbgirshick/py-faster-rcnn/tree/master/data
This dataset returns the following data.
.. csv-table::
:header: name, shape, dtype, format
:obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
"RGB, :math:`[0, 255]`"
:obj:`bbox` [#coco_bbox_1]_, ":math:`(R, 4)`", :obj:`float32`, \
":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
:obj:`label` [#coco_bbox_1]_, ":math:`(R,)`", :obj:`int32`, \
":math:`[0, #fg\_class - 1]`"
:obj:`area` [#coco_bbox_1]_ [#coco_bbox_2]_, ":math:`(R,)`", \
:obj:`float32`, --
:obj:`crowded` [#coco_bbox_3]_, ":math:`(R,)`", :obj:`bool`, --
.. [#coco_bbox_1] If :obj:`use_crowded = True`, :obj:`bbox`, \
:obj:`label` and :obj:`area` contain crowded instances.
.. [#coco_bbox_2] :obj:`area` is available \
if :obj:`return_area = True`.
.. [#coco_bbox_3] :obj:`crowded` is available \
if :obj:`return_crowded = True`.
When there are more than ten objects of the same category in an image,
bounding boxes may correspond to a crowd of instances instead of
individual instances. Please see Fig. 12 (e) of the summary
paper [#]_ for more detail.
.. [#] Tsung-Yi Lin, Michael Maire, Serge Belongie, Lubomir Bourdev, \
Ross Girshick, James Hays, Pietro Perona, Deva Ramanan, \
C. Lawrence Zitnick, Piotr Dollar.
`Microsoft COCO: Common Objects in Context \
<https://arxiv.org/abs/1405.0312>`_. arXiv 2014.
"""

def __init__(self, data_dir='auto', split='train',
use_crowded=False, return_area=False, return_crowded=False):
super(COCOBboxDataset, self).__init__()
self.use_crowded = use_crowded
if split in ['val', 'minival', 'valminusminival']:
img_split = 'val'
else:
img_split = 'train'
if data_dir == 'auto':
data_dir = get_coco(split, img_split)

self.img_root = os.path.join(
data_dir, 'images', '{}2014'.format(img_split))
anno_path = os.path.join(
data_dir, 'annotations', 'instances_{}2014.json'.format(split))

self.data_dir = data_dir
annos = json.load(open(anno_path, 'r'))

self.id_to_prop = {}
for prop in annos['images']:
self.id_to_prop[prop['id']] = prop
self.ids = sorted(list(self.id_to_prop.keys()))

self.cat_ids = [cat['id'] for cat in annos['categories']]

self.id_to_anno = defaultdict(list)
for anno in annos['annotations']:
self.id_to_anno[anno['image_id']].append(anno)

self.add_getter('img', self._get_image)
self.add_getter(['bbox', 'label', 'area', 'crowded'],
self._get_annotations)

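# Note (added for clarity, not in the original commit): each getter
# registered with add_getter above is exposed under the listed key(s);
# the tuple assigned to self.keys below selects which of those keys an
# indexed example actually returns (behavior of GetterDataset).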
keys = ('img', 'bbox', 'label')
if return_area:
keys += ('area',)
if return_crowded:
keys += ('crowded',)
self.keys = keys

def __len__(self):
return len(self.ids)

def _get_image(self, i):
img_path = os.path.join(
self.img_root, self.id_to_prop[self.ids[i]]['file_name'])
img = utils.read_image(img_path, dtype=np.float32, color=True)
return img

def _get_annotations(self, i):
# List[{'segmentation', 'area', 'iscrowd',
# 'image_id', 'bbox', 'category_id', 'id'}]
annotation = self.id_to_anno[self.ids[i]]
bbox = np.array([ann['bbox'] for ann in annotation],
dtype=np.float32)
if len(bbox) == 0:
bbox = np.zeros((0, 4), dtype=np.float32)
# (x, y, width, height) -> (x_min, y_min, x_max, y_max)
bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
# (x_min, y_min, x_max, y_max) -> (y_min, x_min, y_max, x_max)
bbox = bbox[:, [1, 0, 3, 2]]
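# Worked example (illustrative, not in the original commit): a COCO
# annotation with bbox = [x=10, y=20, w=30, h=40] becomes
# (y_min, x_min, y_max, x_max) = (20, 10, 60, 40).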

label = np.array([self.cat_ids.index(ann['category_id'])
for ann in annotation], dtype=np.int32)

area = np.array([ann['area']
for ann in annotation], dtype=np.float32)

crowded = np.array([ann['iscrowd']
for ann in annotation], dtype=np.bool)

# Remove invalid boxes
bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1)
keep_mask = np.logical_and(bbox[:, 0] <= bbox[:, 2],
bbox[:, 1] <= bbox[:, 3])
keep_mask = np.logical_and(keep_mask, bbox_area > 0)

if not self.use_crowded:
keep_mask = np.logical_and(keep_mask, np.logical_not(crowded))

bbox = bbox[keep_mask]
label = label[keep_mask]
area = area[keep_mask]
crowded = crowded[keep_mask]
return bbox, label, area, crowded
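A minimal usage sketch (an illustration, not part of this commit; it assumes the COCO files are already downloaded or that the automatic download under the ChainerCV dataset root is acceptable):

from chainercv.datasets import COCOBboxDataset
from chainercv.datasets import coco_bbox_label_names

# 'minival' is the 5,000-image validation subset; data is fetched on first use.
dataset = COCOBboxDataset(split='minival', return_area=True)
img, bbox, label, area = dataset[0]
print(img.shape)   # (3, H, W), float32, RGB in [0, 255]
print(bbox.shape)  # (R, 4) in (y_min, x_min, y_max, x_max) order
print([coco_bbox_label_names[lb] for lb in label])  # readable class names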
129 changes: 129 additions & 0 deletions chainercv/datasets/coco/coco_utils.py
@@ -0,0 +1,129 @@
import os

from chainer.dataset import download

from chainercv import utils


root = 'pfnet/chainercv/coco'
img_urls = {
'train': 'http://msvocds.blob.core.windows.net/coco2014/train2014.zip',
'val': 'http://msvocds.blob.core.windows.net/coco2014/val2014.zip'
}
anno_urls = {
'train': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/'
'instances_train-val2014.zip',
'val': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/'
'instances_train-val2014.zip',
'valminusminival': 'https://dl.dropboxusercontent.com/s/s3tw5zcg7395368/'
'instances_valminusminival2014.json.zip',
'minival': 'https://dl.dropboxusercontent.com/s/o43o90bna78omob/'
'instances_minival2014.json.zip'
}


def get_coco(split, img_split):
url = img_urls[img_split]
data_dir = download.get_dataset_directory(root)
img_root = os.path.join(data_dir, 'images')
created_img_root = os.path.join(img_root, '{}2014'.format(img_split))
annos_root = os.path.join(data_dir, 'annotations')
anno_path = os.path.join(annos_root, 'instances_{}2014.json'.format(split))
if not os.path.exists(created_img_root):
download_file_path = utils.cached_download(url)
ext = os.path.splitext(url)[1]
utils.extractall(download_file_path, img_root, ext)
if not os.path.exists(anno_path):
anno_url = anno_urls[split]
download_file_path = utils.cached_download(anno_url)
ext = os.path.splitext(anno_url)[1]
utils.extractall(download_file_path, annos_root, ext)
return data_dir
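# Usage sketch (an assumption, not part of this commit): fetching the
# 'minival' annotations together with the val2014 images would be
# get_coco('minival', 'val'), which returns the dataset root directory.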


# How you can get the labels
# >>> from pycocotools.coco import COCO
# >>> coco = COCO('instances_train2014.json')
# >>> cat_dict = coco.loadCats(coco.getCatIds())
# >>> coco_bbox_label_names = [c['name'] for c in cat_dict]
coco_bbox_label_names = (
'person',
'bicycle',
'car',
'motorcycle',
'airplane',
'bus',
'train',
'truck',
'boat',
'traffic light',
'fire hydrant',
'stop sign',
'parking meter',
'bench',
'bird',
'cat',
'dog',
'horse',
'sheep',
'cow',
'elephant',
'bear',
'zebra',
'giraffe',
'backpack',
'umbrella',
'handbag',
'tie',
'suitcase',
'frisbee',
'skis',
'snowboard',
'sports ball',
'kite',
'baseball bat',
'baseball glove',
'skateboard',
'surfboard',
'tennis racket',
'bottle',
'wine glass',
'cup',
'fork',
'knife',
'spoon',
'bowl',
'banana',
'apple',
'sandwich',
'orange',
'broccoli',
'carrot',
'hot dog',
'pizza',
'donut',
'cake',
'chair',
'couch',
'potted plant',
'bed',
'dining table',
'toilet',
'tv',
'laptop',
'mouse',
'remote',
'keyboard',
'cell phone',
'microwave',
'oven',
'toaster',
'sink',
'refrigerator',
'book',
'clock',
'vase',
'scissors',
'teddy bear',
'hair drier',
'toothbrush')
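The tuple above maps the contiguous label indices produced by COCOBboxDataset back to the 80 COCO category names. An illustrative lookup, mirroring the comment style above:
# >>> coco_bbox_label_names[0]
# 'person'
# >>> len(coco_bbox_label_names)
# 80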
9 changes: 9 additions & 0 deletions docs/source/reference/datasets.rst
@@ -51,6 +51,7 @@ CityscapesTestImageDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CityscapesTestImageDataset


CUB
---

@@ -63,6 +64,14 @@ CUBPointDataset
.. autoclass:: CUBPointDataset


MS COCO
-------

COCOBboxDataset
~~~~~~~~~~~~~~~
.. autoclass:: COCOBboxDataset


OnlineProducts
--------------

54 changes: 54 additions & 0 deletions tests/datasets_tests/coco_tests/test_coco_bbox_dataset.py
@@ -0,0 +1,54 @@
import unittest

import numpy as np

from chainer import testing
from chainer.testing import attr

from chainercv.datasets import coco_bbox_label_names
from chainercv.datasets import COCOBboxDataset
from chainercv.utils import assert_is_bbox_dataset


@testing.parameterize(*testing.product({
'split': ['train', 'val', 'minival', 'valminusminival'],
'use_crowded': [False, True],
'return_area': [False, True],
'return_crowded': [False, True]
}))
class TestCOCOBboxDataset(unittest.TestCase):

def setUp(self):
self.dataset = COCOBboxDataset(
split=self.split,
use_crowded=self.use_crowded, return_area=self.return_area,
return_crowded=self.return_crowded)

@attr.slow
def test_coco_bbox_dataset(self):
assert_is_bbox_dataset(
self.dataset, len(coco_bbox_label_names), n_example=30)

if self.return_area:
for _ in range(10):
i = np.random.randint(0, len(self.dataset))
_, bbox, _, area = self.dataset[i][:4]
self.assertIsInstance(area, np.ndarray)
self.assertEqual(area.dtype, np.float32)
self.assertEqual(area.shape, (bbox.shape[0],))

if self.return_crowded:
for _ in range(10):
i = np.random.randint(0, len(self.dataset))
example = self.dataset[i]
crowded = example[-1]
bbox = example[1]
self.assertIsInstance(crowded, np.ndarray)
self.assertEqual(crowded.dtype, np.bool)
self.assertEqual(crowded.shape, (bbox.shape[0],))

if not self.use_crowded:
np.testing.assert_equal(crowded, 0)


testing.run_module(__name__, __file__)
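The parameterized test above is marked slow because it downloads the dataset for every split. A hedged invocation sketch, assuming a pytest-based runner is available:

python -m pytest tests/datasets_tests/coco_tests/test_coco_bbox_dataset.py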
