Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add COCOBboxDataset #453

Merged
merged 39 commits into from
Jun 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
30be97c
add coco_detection_dataset
yuyu2172 May 24, 2017
6119d55
use bbox_dataset notation and add labels method
yuyu2172 Oct 13, 2017
2c27e42
merge master
yuyu2172 Oct 13, 2017
403fe5d
be explict about year
yuyu2172 Oct 13, 2017
7a95a3b
Merge branch 'assert-is-bbox-dataset-zero-sized' into coco_detection
yuyu2172 Oct 13, 2017
ba331d5
add test and modify to pass the test
yuyu2172 Oct 13, 2017
281fe55
fix doc
yuyu2172 Oct 13, 2017
27d44a9
flake8
yuyu2172 Oct 13, 2017
25402df
fix failing test
yuyu2172 Oct 13, 2017
c48937c
conduct sanitization inside _get_annotations
yuyu2172 Oct 13, 2017
2eb19d9
flake8
yuyu2172 Oct 13, 2017
4d842fe
delete unnecessary sanitization
yuyu2172 Oct 13, 2017
326eb6d
sort ids
yuyu2172 Oct 14, 2017
ba070d7
add return_area option
yuyu2172 Oct 14, 2017
5c11243
fix failing test
yuyu2172 Oct 14, 2017
682009f
merge master
yuyu2172 Apr 9, 2018
3c0af7b
Merge remote-tracking branch 'Hakuyume/annotated-dataset-mixin' into …
yuyu2172 Apr 9, 2018
d5dfb83
WIP
yuyu2172 Apr 9, 2018
3a671f6
Merge remote-tracking branch 'origin/master' into coco_detection
yuyu2172 Apr 18, 2018
53796f9
change the default return values and update doc
yuyu2172 Apr 18, 2018
f9b6dd3
fix doc
yuyu2172 Apr 18, 2018
2309c54
fix bug
yuyu2172 Apr 18, 2018
607f262
merge
yuyu2172 Apr 18, 2018
ae35e56
Merge remote-tracking branch 'origin/master' into coco_detection
yuyu2172 Apr 19, 2018
5b6a262
use tuple to set keys
yuyu2172 Apr 19, 2018
1749887
fix doc
yuyu2172 Apr 19, 2018
b4f53b5
clean up __init__
yuyu2172 Apr 19, 2018
a133d1b
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 Apr 19, 2018
393a4c4
indentation csv
yuyu2172 Apr 19, 2018
1d09d4e
clean up
yuyu2172 Apr 19, 2018
7871347
fix order of data
yuyu2172 Jun 4, 2018
8d361fc
change variable names
yuyu2172 Jun 4, 2018
2025cd7
fix
yuyu2172 Jun 4, 2018
372cba2
Merge remote-tracking branch 'origin/master' into coco_detection
yuyu2172 Jun 4, 2018
4fd1777
fix order or argument
yuyu2172 Jun 5, 2018
81e0f94
fix var name
yuyu2172 Jun 5, 2018
53b5d97
fix var name
yuyu2172 Jun 5, 2018
63cfdf1
Merge remote-tracking branch 'origin/master' into coco_detection
yuyu2172 Jun 5, 2018
e64d00d
reflect comments
yuyu2172 Jun 5, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chainercv/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from chainercv.datasets.cityscapes.cityscapes_test_image_dataset import CityscapesTestImageDataset # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_colors # NOQA
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_names # NOQA
from chainercv.datasets.coco.coco_bbox_dataset import COCOBboxDataset # NOQA
from chainercv.datasets.coco.coco_utils import coco_bbox_label_names # NOQA
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
from chainercv.datasets.cub.cub_point_dataset import CUBPointDataset # NOQA
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
Expand Down
Empty file.
163 changes: 163 additions & 0 deletions chainercv/datasets/coco/coco_bbox_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from collections import defaultdict
import json
import numpy as np
import os

from chainercv import utils

from chainercv.datasets.coco.coco_utils import get_coco

from chainercv.chainer_experimental.datasets.sliceable import GetterDataset


class COCOBboxDataset(GetterDataset):

"""Bounding box dataset for `MS COCO2014`_.

.. _`MS COCO2014`: http://mscoco.org/dataset/#detections-challenge2015

There are total of 82,783 training and 40,504 validation images.
'minval' split is a subset of validation images that constitutes
5,000 images in the validation images. The remaining validation
images are called 'minvalminus'. Concrete list of image ids and
annotations for these splits are found `here`_.

Args:
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/coco`.
split ({'train', 'val', 'minival', 'valminusminival'}): Select
a split of the dataset.
use_crowded (bool): If true, use bounding boxes that are labeled as
crowded in the original annotation. The default value is
:obj:`False`.
return_area (bool): If true, this dataset returns areas of masks
around objects. The default value is :obj:`False`.
return_crowded (bool): If true, this dataset returns a boolean array
that indicates whether bounding boxes are labeled as crowded
or not. The default value is :obj:`False`.

.. _`here`: https://github.com/rbgirshick/py-faster-rcnn/tree/master/data

This dataset returns the following data.

.. csv-table::
:header: name, shape, dtype, format

:obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
"RGB, :math:`[0, 255]`"
:obj:`bbox` [#coco_bbox_1]_, ":math:`(R, 4)`", :obj:`float32`, \
":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
:obj:`label` [#coco_bbox_1]_, ":math:`(R,)`", :obj:`int32`, \
":math:`[0, #fg\_class - 1]`"
:obj:`area` [#coco_bbox_1]_ [#coco_bbox_2]_, ":math:`(R,)`", \
:obj:`float32`, --
:obj:`crowded` [#coco_bbox_3]_, ":math:`(R,)`", :obj:`bool`, --

.. [#coco_bbox_1] If :obj:`use_crowded = True`, :obj:`bbox`, \
:obj:`label` and :obj:`area` contain crowded instances.
.. [#coco_bbox_2] :obj:`area` is available \
if :obj:`return_area = True`.
.. [#coco_bbox_3] :obj:`crowded` is available \
if :obj:`return_crowded = True`.

When there are more than ten objects from the same category,
bounding boxes correspond to crowd of instances instead of individual
instances. Please see more detail in the Fig. 12 (e) of the summary
paper [#]_.

.. [#] Tsung-Yi Lin, Michael Maire, Serge Belongie, Lubomir Bourdev, \
Ross Girshick, James Hays, Pietro Perona, Deva Ramanan, \
C. Lawrence Zitnick, Piotr Dollar.
`Microsoft COCO: Common Objects in Context \
<https://arxiv.org/abs/1405.0312>`_. arXiv 2014.

"""

def __init__(self, data_dir='auto', split='train',
use_crowded=False, return_area=False, return_crowded=False):
super(COCOBboxDataset, self).__init__()
self.use_crowded = use_crowded
if split in ['val', 'minival', 'valminusminival']:
img_split = 'val'
else:
img_split = 'train'
if data_dir == 'auto':
data_dir = get_coco(split, img_split)

self.img_root = os.path.join(
data_dir, 'images', '{}2014'.format(img_split))
anno_path = os.path.join(
data_dir, 'annotations', 'instances_{}2014.json'.format(split))

self.data_dir = data_dir
annos = json.load(open(anno_path, 'r'))

self.id_to_prop = {}
for prop in annos['images']:
self.id_to_prop[prop['id']] = prop
self.ids = sorted(list(self.id_to_prop.keys()))

self.cat_ids = [cat['id'] for cat in annos['categories']]

self.id_to_anno = defaultdict(list)
for anno in annos['annotations']:
self.id_to_anno[anno['image_id']].append(anno)

self.add_getter('img', self._get_image)
self.add_getter(['bbox', 'label', 'area', 'crowded'],
self._get_annotations)

keys = ('img', 'bbox', 'label')
if return_area:
keys += ('area',)
if return_crowded:
keys += ('crowded',)
self.keys = keys

def __len__(self):
return len(self.ids)

def _get_image(self, i):
img_path = os.path.join(
self.img_root, self.id_to_prop[self.ids[i]]['file_name'])
img = utils.read_image(img_path, dtype=np.float32, color=True)
return img

def _get_annotations(self, i):
# List[{'segmentation', 'area', 'iscrowd',
# 'image_id', 'bbox', 'category_id', 'id'}]
annotation = self.id_to_anno[self.ids[i]]
bbox = np.array([ann['bbox'] for ann in annotation],
dtype=np.float32)
if len(bbox) == 0:
bbox = np.zeros((0, 4), dtype=np.float32)
# (x, y, width, height) -> (x_min, y_min, x_max, y_max)
bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
# (x_min, y_min, x_max, y_max) -> (y_min, x_min, y_max, x_max)
bbox = bbox[:, [1, 0, 3, 2]]

label = np.array([self.cat_ids.index(ann['category_id'])
for ann in annotation], dtype=np.int32)

area = np.array([ann['area']
for ann in annotation], dtype=np.float32)

crowded = np.array([ann['iscrowd']
for ann in annotation], dtype=np.bool)

# Remove invalid boxes
bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1)
keep_mask = np.logical_and(bbox[:, 0] <= bbox[:, 2],
bbox[:, 1] <= bbox[:, 3])
keep_mask = np.logical_and(keep_mask, bbox_area > 0)

if not self.use_crowded:
keep_mask = np.logical_and(keep_mask, np.logical_not(crowded))

bbox = bbox[keep_mask]
label = label[keep_mask]
area = area[keep_mask]
crowded = crowded[keep_mask]
return bbox, label, area, crowded
129 changes: 129 additions & 0 deletions chainercv/datasets/coco/coco_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os

from chainer.dataset import download

from chainercv import utils


root = 'pfnet/chainercv/coco'
img_urls = {
'train': 'http://msvocds.blob.core.windows.net/coco2014/train2014.zip',
'val': 'http://msvocds.blob.core.windows.net/coco2014/val2014.zip'
}
anno_urls = {
'train': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/'
'instances_train-val2014.zip',
'val': 'http://msvocds.blob.core.windows.net/annotations-1-0-3/'
'instances_train-val2014.zip',
'valminusminival': 'https://dl.dropboxusercontent.com/s/s3tw5zcg7395368/'
'instances_valminusminival2014.json.zip',
'minival': 'https://dl.dropboxusercontent.com/s/o43o90bna78omob/'
'instances_minival2014.json.zip'
}


def get_coco(split, img_split):
url = img_urls[img_split]
data_dir = download.get_dataset_directory(root)
img_root = os.path.join(data_dir, 'images')
created_img_root = os.path.join(img_root, '{}2014'.format(img_split))
annos_root = os.path.join(data_dir, 'annotations')
anno_path = os.path.join(annos_root, 'instances_{}2014.json'.format(split))
if not os.path.exists(created_img_root):
download_file_path = utils.cached_download(url)
ext = os.path.splitext(url)[1]
utils.extractall(download_file_path, img_root, ext)
if not os.path.exists(anno_path):
anno_url = anno_urls[split]
download_file_path = utils.cached_download(anno_url)
ext = os.path.splitext(anno_url)[1]
utils.extractall(download_file_path, annos_root, ext)
return data_dir


# How you can get the labels
# >>> from pycocotools.coco import COCO
# >>> coco = COCO('instances_train2014.json')
# >>> cat_dict = coco.loadCats(coco.getCatIds())
# >>> coco_bbox_label_names = [c['name'] for c in cat_dict]
coco_bbox_label_names = (
'person',
'bicycle',
'car',
'motorcycle',
'airplane',
'bus',
'train',
'truck',
'boat',
'traffic light',
'fire hydrant',
'stop sign',
'parking meter',
'bench',
'bird',
'cat',
'dog',
'horse',
'sheep',
'cow',
'elephant',
'bear',
'zebra',
'giraffe',
'backpack',
'umbrella',
'handbag',
'tie',
'suitcase',
'frisbee',
'skis',
'snowboard',
'sports ball',
'kite',
'baseball bat',
'baseball glove',
'skateboard',
'surfboard',
'tennis racket',
'bottle',
'wine glass',
'cup',
'fork',
'knife',
'spoon',
'bowl',
'banana',
'apple',
'sandwich',
'orange',
'broccoli',
'carrot',
'hot dog',
'pizza',
'donut',
'cake',
'chair',
'couch',
'potted plant',
'bed',
'dining table',
'toilet',
'tv',
'laptop',
'mouse',
'remote',
'keyboard',
'cell phone',
'microwave',
'oven',
'toaster',
'sink',
'refrigerator',
'book',
'clock',
'vase',
'scissors',
'teddy bear',
'hair drier',
'toothbrush')
9 changes: 9 additions & 0 deletions docs/source/reference/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ CityscapesTestImageDataset
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CityscapesTestImageDataset


CUB
---

Expand All @@ -63,6 +64,14 @@ CUBPointDataset
.. autoclass:: CUBPointDataset


MS COCO
-------

COCOBboxDataset
~~~~~~~~~~~~~~~
.. autoclass:: COCOBboxDataset


OnlineProducts
--------------

Expand Down
54 changes: 54 additions & 0 deletions tests/datasets_tests/coco_tests/test_coco_bbox_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest

import numpy as np

from chainer import testing
from chainer.testing import attr

from chainercv.datasets import coco_bbox_label_names
from chainercv.datasets import COCOBboxDataset
from chainercv.utils import assert_is_bbox_dataset


@testing.parameterize(*testing.product({
'split': ['train', 'val', 'minival', 'valminusminival'],
'use_crowded': [False, True],
'return_area': [False, True],
'return_crowded': [False, True]
}))
class TestCOCOBboxDataset(unittest.TestCase):

def setUp(self):
self.dataset = COCOBboxDataset(
split=self.split,
use_crowded=self.use_crowded, return_area=self.return_area,
return_crowded=self.return_crowded)

@attr.slow
def test_coco_bbox_dataset(self):
assert_is_bbox_dataset(
self.dataset, len(coco_bbox_label_names), n_example=30)

if self.return_area:
for _ in range(10):
i = np.random.randint(0, len(self.dataset))
_, bbox, _, area = self.dataset[i][:4]
self.assertIsInstance(area, np.ndarray)
self.assertEqual(area.dtype, np.float32)
self.assertEqual(area.shape, (bbox.shape[0],))

if self.return_crowded:
for _ in range(10):
i = np.random.randint(0, len(self.dataset))
example = self.dataset[i]
crowded = example[-1]
bbox = example[1]
self.assertIsInstance(crowded, np.ndarray)
self.assertEqual(crowded.dtype, np.bool)
self.assertEqual(crowded.shape, (bbox.shape[0],))

if not self.use_crowded:
np.testing.assert_equal(crowded, 0)


testing.run_module(__name__, __file__)