This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 306
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
change assert_is_point, assert_is_point_dataset, CUB datasets
- Loading branch information
Showing
10 changed files
with
173 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,52 @@ | ||
import numpy as np | ||
|
||
|
||
def assert_is_point(point, mask=None, size=None): | ||
def assert_is_point(point, visible=None, size=None, n_point=None): | ||
"""Checks if points satisfy the format. | ||
This function checks if given points satisfy the format and | ||
raises an :class:`AssertionError` when the points violate the convention. | ||
Args: | ||
point (~numpy.ndarray): Points to be checked. | ||
mask (~numpy.ndarray): A mask of the points. | ||
If this is :obj:`None`, all points are regarded as valid. | ||
visible (~numpy.ndarray): Visibility of the points. | ||
If this is :obj:`None`, all points are regarded as visible. | ||
size (tuple of ints): The size of an image. | ||
If this argument is specified, | ||
the coordinates of valid points are checked to be within the image. | ||
the coordinates of visible points are checked to be within the image. | ||
n_point (int): If specified, the number of points in each object is | ||
expected to be :obj:`n_point`. | ||
""" | ||
|
||
assert isinstance(point, np.ndarray), \ | ||
'point must be a numpy.ndarray.' | ||
assert point.dtype == np.float32, \ | ||
'The type of point must be numpy.float32.' | ||
assert point.shape[1:] == (2,), \ | ||
'The shape of point must be (*, 2).' | ||
for i, pnt in enumerate(point): | ||
assert isinstance(pnt, np.ndarray), \ | ||
'pnt must be a numpy.ndarray.' | ||
assert pnt.dtype == np.float32, \ | ||
'The type of pnt must be numpy.float32.' | ||
assert pnt.shape[1:] == (2,), \ | ||
'The shape of pnt must be (*, 2).' | ||
if n_point is not None: | ||
assert pnt.shape[0] == n_point, \ | ||
'The number of points should always be n_point' | ||
|
||
if mask is not None: | ||
assert isinstance(mask, np.ndarray), \ | ||
'a mask of points must be a numpy.ndarray.' | ||
assert mask.dtype == np.bool, \ | ||
'The type of mask must be numpy.bool.' | ||
assert mask.ndim == 1, \ | ||
'The dimensionality of a mask must be one.' | ||
assert mask.shape[0] == point.shape[0], \ | ||
'The size of the first axis should be the same for ' \ | ||
'corresponding point and mask.' | ||
valid_point = point[mask] | ||
else: | ||
valid_point = point | ||
if visible is not None: | ||
assert len(point) == len(visible), \ | ||
'The length of point and visible should be the same.' | ||
vsble = visible[i] | ||
assert isinstance(vsble, np.ndarray), \ | ||
'pnt should be a numpy.ndarray.' | ||
assert vsble.dtype == np.bool, \ | ||
'The type of visible must be numpy.bool.' | ||
assert vsble.ndim == 1, \ | ||
'The dimensionality of a visible must be one.' | ||
assert vsble.shape[0] == pnt.shape[0], \ | ||
'The size of the first axis should be the same for ' \ | ||
'corresponding pnt and vsble.' | ||
visible_pnt = pnt[vsble] | ||
else: | ||
visible_pnt = pnt | ||
|
||
if size is not None: | ||
assert (valid_point >= 0).all() and (valid_point <= size).all(),\ | ||
'The coordinates of valid points ' \ | ||
'should not exceed the size of image.' | ||
if size is not None: | ||
assert (visible_pnt >= 0).all() and (visible_pnt <= size).all(),\ | ||
'The coordinates of visible points ' \ | ||
'should not exceed the size of image.' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.