Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
change assert_is_point, assert_is_point_dataset, CUB datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Feb 21, 2019
1 parent 3f2a580 commit beb4777
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 148 deletions.
2 changes: 1 addition & 1 deletion chainercv/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from chainercv.datasets.coco.coco_utils import coco_semantic_segmentation_label_colors # NOQA
from chainercv.datasets.coco.coco_utils import coco_semantic_segmentation_label_names # NOQA
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
from chainercv.datasets.cub.cub_point_dataset import CUBPointDataset # NOQA
from chainercv.datasets.cub.cub_keypoint_dataset import CUBKeypointDataset # NOQA
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA
from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chainercv import utils


class CUBPointDataset(CUBDatasetBase):
class CUBKeypointDataset(CUBDatasetBase):

"""`Caltech-UCSD Birds-200-2011`_ dataset with annotated points.
Expand All @@ -17,7 +17,7 @@ class CUBPointDataset(CUBDatasetBase):
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_bb (bool): If :obj:`True`, this returns a bounding box
return_bbox (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
prob_map_dir (string): Path to the root of the probability maps.
If this is :obj:`auto`, this class will automatically download data
Expand All @@ -33,29 +33,29 @@ class CUBPointDataset(CUBDatasetBase):
:obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
"RGB, :math:`[0, 255]`"
:obj:`point`, ":math:`(P, 2)`", :obj:`float32`, ":math:`(y, x)`"
:obj:`mask`, ":math:`(P,)`", :obj:`bool`, --
:obj:`bb` [#cub_point_1]_, ":math:`(4,)`", :obj:`float32`, \
:obj:`point`, ":math:`(1, 15, 2)`", :obj:`float32`, ":math:`(y, x)`"
:obj:`visible`, ":math:`(1, 15)`", :obj:`bool`, --
:obj:`bbox` [#cub_point_1]_, ":math:`(1, 4)`", :obj:`float32`, \
":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
:obj:`prob_map` [#cub_point_2]_, ":math:`(H, W)`", :obj:`float32`, \
":math:`[0, 1]`"
.. [#cub_point_1] :obj:`bb` indicates the location of a bird. \
It is available if :obj:`return_bb = True`.
It is available if :obj:`return_bbox = True`.
.. [#cub_point_2] :obj:`prob_map` indicates how likey a bird is located \
at each the pixel. \
It is available if :obj:`return_prob_map = True`.
"""

def __init__(self, data_dir='auto', return_bb=False,
def __init__(self, data_dir='auto', return_bbox=False,
prob_map_dir='auto', return_prob_map=False):
super(CUBPointDataset, self).__init__(data_dir, prob_map_dir)
super(CUBKeypointDataset, self).__init__(data_dir, prob_map_dir)

# load point
parts_loc_file = os.path.join(self.data_dir, 'parts', 'part_locs.txt')
self._point_dict = collections.defaultdict(list)
self._mask_dict = collections.defaultdict(list)
self._visible_dict = collections.defaultdict(list)
for loc in open(parts_loc_file):
values = loc.split()
id_ = int(values[0]) - 1
Expand All @@ -65,14 +65,14 @@ def __init__(self, data_dir='auto', return_bb=False,
mask = bool(int(values[4]))

self._point_dict[id_].append(point)
self._mask_dict[id_].append(mask)
self._visible_dict[id_].append(mask)

self.add_getter(('img', 'point', 'mask'),
self.add_getter(('img', 'point', 'visible'),
self._get_img_and_annotations)

keys = ('img', 'point', 'mask')
if return_bb:
keys += ('bb',)
keys = ('img', 'point', 'visible')
if return_bbox:
keys += ('bbox',)
if return_prob_map:
keys += ('prob_map',)
self.keys = keys
Expand All @@ -82,12 +82,12 @@ def _get_img_and_annotations(self, i):
os.path.join(self.data_dir, 'images', self.paths[i]),
color=True)

point = np.array(self._point_dict[i], dtype=np.float32)
mask = np.array(self._mask_dict[i], dtype=np.bool)
pnt = np.array(self._point_dict[i], dtype=np.float32)
vsble = np.array(self._visible_dict[i], dtype=np.bool)

_, H, W = img.shape
invalid = np.logical_or(
np.logical_or(point[:, 0] > H, point[:, 1] > W),
np.any(point < 0, axis=1))
mask[invalid] = False
return img, point, mask
invisible = np.logical_or(
np.logical_or(pnt[:, 0] > H, pnt[:, 1] > W),
np.any(pnt < 0, axis=1))
vsble[invisible] = False
return img, pnt[None], vsble[None]
12 changes: 6 additions & 6 deletions chainercv/datasets/cub/cub_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class CUBLabelDataset(CUBDatasetBase):
data_dir (string): Path to the root of the training data. If this is
:obj:`auto`, this class will automatically download data for you
under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/cub`.
return_bb (bool): If :obj:`True`, this returns a bounding box
return_bbox (bool): If :obj:`True`, this returns a bounding box
around a bird. The default value is :obj:`False`.
prob_map_dir (string): Path to the root of the probability maps.
If this is :obj:`auto`, this class will automatically download data
Expand All @@ -33,19 +33,19 @@ class CUBLabelDataset(CUBDatasetBase):
:obj:`img`, ":math:`(3, H, W)`", :obj:`float32`, \
"RGB, :math:`[0, 255]`"
:obj:`label`, scalar, :obj:`int32`, ":math:`[0, \#class - 1]`"
:obj:`bb` [#cub_label_1]_, ":math:`(4,)`", :obj:`float32`, \
:obj:`bbox` [#cub_label_1]_, ":math:`(1, 4)`", :obj:`float32`, \
":math:`(y_{min}, x_{min}, y_{max}, x_{max})`"
:obj:`prob_map` [#cub_label_2]_, ":math:`(H, W)`", :obj:`float32`, \
":math:`[0, 1]`"
.. [#cub_label_1] :obj:`bb` indicates the location of a bird. \
It is available if :obj:`return_bb = True`.
It is available if :obj:`return_bbox = True`.
.. [#cub_label_2] :obj:`prob_map` indicates how likey a bird is located \
at each the pixel. \
It is available if :obj:`return_prob_map = True`.
"""

def __init__(self, data_dir='auto', return_bb=False,
def __init__(self, data_dir='auto', return_bbox=False,
prob_map_dir='auto', return_prob_map=False):
super(CUBLabelDataset, self).__init__(data_dir, prob_map_dir)

Expand All @@ -59,8 +59,8 @@ def __init__(self, data_dir='auto', return_bb=False,
self.add_getter('label', self._get_label)

keys = ('img', 'label')
if return_bb:
keys += ('bb',)
if return_bbox:
keys += ('bbox',)
if return_prob_map:
keys += ('prob_map',)
self.keys = keys
Expand Down
6 changes: 3 additions & 3 deletions chainercv/datasets/cub/cub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ def __init__(self, data_dir='auto', prob_map_dir='auto'):
os.path.join(self.prob_map_dir, os.path.splitext(path)[0] + '.png')
for path in self.paths]

self.add_getter('bb', self._get_bb)
self.add_getter('bbox', self._get_bbox)
self.add_getter('prob_map', self._get_prob_map)

def __len__(self):
return len(self.paths)

def _get_bb(self, i):
return self.bbs[i]
def _get_bbox(self, i):
return self.bbs[i][None]

def _get_prob_map(self, i):
prob_map = utils.read_label(self.prob_map_paths[i], dtype=np.uint8)
Expand Down
63 changes: 36 additions & 27 deletions chainercv/utils/testing/assertions/assert_is_point.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,52 @@
import numpy as np


def assert_is_point(point, mask=None, size=None):
def assert_is_point(point, visible=None, size=None, n_point=None):
"""Checks if points satisfy the format.
This function checks if given points satisfy the format and
raises an :class:`AssertionError` when the points violate the convention.
Args:
point (~numpy.ndarray): Points to be checked.
mask (~numpy.ndarray): A mask of the points.
If this is :obj:`None`, all points are regarded as valid.
visible (~numpy.ndarray): Visibility of the points.
If this is :obj:`None`, all points are regarded as visible.
size (tuple of ints): The size of an image.
If this argument is specified,
the coordinates of valid points are checked to be within the image.
the coordinates of visible points are checked to be within the image.
n_point (int): If specified, the number of points in each object is
expected to be :obj:`n_point`.
"""

assert isinstance(point, np.ndarray), \
'point must be a numpy.ndarray.'
assert point.dtype == np.float32, \
'The type of point must be numpy.float32.'
assert point.shape[1:] == (2,), \
'The shape of point must be (*, 2).'
for i, pnt in enumerate(point):
assert isinstance(pnt, np.ndarray), \
'pnt must be a numpy.ndarray.'
assert pnt.dtype == np.float32, \
'The type of pnt must be numpy.float32.'
assert pnt.shape[1:] == (2,), \
'The shape of pnt must be (*, 2).'
if n_point is not None:
assert pnt.shape[0] == n_point, \
'The number of points should always be n_point'

if mask is not None:
assert isinstance(mask, np.ndarray), \
'a mask of points must be a numpy.ndarray.'
assert mask.dtype == np.bool, \
'The type of mask must be numpy.bool.'
assert mask.ndim == 1, \
'The dimensionality of a mask must be one.'
assert mask.shape[0] == point.shape[0], \
'The size of the first axis should be the same for ' \
'corresponding point and mask.'
valid_point = point[mask]
else:
valid_point = point
if visible is not None:
assert len(point) == len(visible), \
'The length of point and visible should be the same.'
vsble = visible[i]
assert isinstance(vsble, np.ndarray), \
'pnt should be a numpy.ndarray.'
assert vsble.dtype == np.bool, \
'The type of visible must be numpy.bool.'
assert vsble.ndim == 1, \
'The dimensionality of a visible must be one.'
assert vsble.shape[0] == pnt.shape[0], \
'The size of the first axis should be the same for ' \
'corresponding pnt and vsble.'
visible_pnt = pnt[vsble]
else:
visible_pnt = pnt

if size is not None:
assert (valid_point >= 0).all() and (valid_point <= size).all(),\
'The coordinates of valid points ' \
'should not exceed the size of image.'
if size is not None:
assert (visible_pnt >= 0).all() and (visible_pnt <= size).all(),\
'The coordinates of visible points ' \
'should not exceed the size of image.'
28 changes: 12 additions & 16 deletions chainercv/utils/testing/assertions/assert_is_point_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def assert_is_point_dataset(dataset, n_point=None, n_example=None,
no_mask=False):
no_visible=False):
"""Checks if a dataset satisfies the point dataset API.
This function checks if a given dataset satisfies the point dataset
Expand All @@ -23,9 +23,9 @@ def assert_is_point_dataset(dataset, n_point=None, n_example=None,
If this argument is specified, this function picks
examples ramdomly and checks them. Otherwise,
this function checks all examples.
no_mask (bool): If :obj:`True`, we assume that
point mask is always not contained.
If :obj:`False`, point mask may or may not be contained.
no_visible (bool): If :obj:`True`, we assume that
visibility mask is always not contained.
If :obj:`False`, point visible may or may not be contained.
"""

Expand All @@ -34,26 +34,22 @@ def assert_is_point_dataset(dataset, n_point=None, n_example=None,
if n_example:
for _ in six.moves.range(n_example):
i = np.random.randint(0, len(dataset))
_check_example(dataset[i], n_point, no_mask)
_check_example(dataset[i], n_point, no_visible)
else:
for i in six.moves.range(len(dataset)):
_check_example(dataset[i], n_point, no_mask)
_check_example(dataset[i], n_point, no_visible)


def _check_example(example, n_point=None, no_mask=False):
def _check_example(example, n_point=None, no_visible=False):
assert len(example) >= 2, \
'Each example must have at least two elements:' \
'img, point (mask is optional).'
'img, point (visible is optional).'

if len(example) == 2 or no_mask:
if len(example) == 2 or no_visible:
img, point = example[:2]
mask = None
visible = None
elif len(example) >= 3:
img, point, mask = example[:3]
img, point, visible = example[:3]

assert_is_image(img, color=True)
assert_is_point(point, mask, img.shape[1:])

if n_point is not None:
assert point.shape[0] == n_point, \
'The number of points is different from the expected number.'
assert_is_point(point, visible, img.shape[1:], n_point)
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@
from chainer import testing
from chainer.testing import attr

from chainercv.datasets import CUBPointDataset
from chainercv.datasets import CUBKeypointDataset
from chainercv.utils import assert_is_bbox
from chainercv.utils import assert_is_point_dataset


@testing.parameterize(*testing.product({
'return_bb': [True, False],
'return_bbox': [True, False],
'return_prob_map': [True, False]}
))
class TestCUBPointDataset(unittest.TestCase):
class TestCUBKeypointDataset(unittest.TestCase):

def setUp(self):
self.dataset = CUBPointDataset(return_bb=self.return_bb,
return_prob_map=self.return_prob_map)
self.dataset = CUBKeypointDataset(return_bbox=self.return_bbox,
return_prob_map=self.return_prob_map)

@attr.slow
def test_cub_point_dataset(self):
assert_is_point_dataset(
self.dataset, n_point=15, n_example=10)

idx = np.random.choice(np.arange(10))
if self.return_bb:
if self.return_bbox:
if self.return_prob_map:
bb = self.dataset[idx][-2]
bbox = self.dataset[idx][-2]
else:
bb = self.dataset[idx][-1]
assert_is_bbox(bb[np.newaxis])
bbox = self.dataset[idx][-1]
assert_is_bbox(bbox)
if self.return_prob_map:
img = self.dataset[idx][0]
prob_map = self.dataset[idx][-1]
Expand Down
10 changes: 5 additions & 5 deletions tests/datasets_tests/cub_tests/test_cub_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@


@testing.parameterize(*testing.product({
'return_bb': [True, False],
'return_bbox': [True, False],
'return_prob_map': [True, False]
}))
class TestCUBLabelDataset(unittest.TestCase):

def setUp(self):
self.dataset = CUBLabelDataset(
return_bb=self.return_bb, return_prob_map=self.return_prob_map)
return_bbox=self.return_bbox, return_prob_map=self.return_prob_map)

@attr.slow
def test_cub_label_dataset(self):
assert_is_label_dataset(
self.dataset, len(cub_label_names), n_example=10)
idx = np.random.choice(np.arange(10))
if self.return_bb:
bb = self.dataset[idx][2]
assert_is_bbox(bb[np.newaxis])
if self.return_bbox:
bbox = self.dataset[idx][2]
assert_is_bbox(bbox)
if self.return_prob_map:
img = self.dataset[idx][0]
prob_map = self.dataset[idx][-1]
Expand Down

0 comments on commit beb4777

Please sign in to comment.