Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

add sort_by_score in vis_bbox #801

Merged
merged 2 commits into from Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion chainercv/visualizations/vis_bbox.py
Expand Up @@ -4,7 +4,8 @@


def vis_bbox(img, bbox, label=None, score=None, label_names=None,
instance_colors=None, alpha=1., linewidth=3., ax=None):
instance_colors=None, alpha=1., linewidth=3.,
sort_by_score=True, ax=None):
"""Visualize bounding boxes inside image.

Example:
Expand Down Expand Up @@ -60,6 +61,8 @@ def vis_bbox(img, bbox, label=None, score=None, label_names=None,
alpha (float): The value which determines transparency of the
bounding boxes. The range of this value is :math:`[0, 1]`.
linewidth (float): The thickness of the edges of the bounding boxes.
sort_by_score (bool): When :obj:`True`, instances with high scores
are always visualized in front of instances with low scores.
ax (matplotlib.axes.Axis): The visualization is displayed on this
axis. If this is :obj:`None` (default), a new axis is created.

Expand All @@ -75,6 +78,13 @@ def vis_bbox(img, bbox, label=None, score=None, label_names=None,
if score is not None and not len(bbox) == len(score):
raise ValueError('The length of score must be same as that of bbox')

if sort_by_score and score is not None:
order = np.argsort(score)
bbox = bbox[order]
score = score[order]
if label is not None:
label = label[order]

# Returns newly instantiated matplotlib.axes.Axes object if ax is None
ax = vis_image(img, ax=ax)

Expand Down
77 changes: 40 additions & 37 deletions tests/visualizations_tests/test_vis_bbox.py
Expand Up @@ -15,39 +15,40 @@


@testing.parameterize(
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': None,
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': None,
'label_names': None},
{
'n_bbox': 3, 'label': (0, 1, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 0, 'label': (), 'score': (),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2'), 'no_img': True},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2'),
'instance_colors': [
(255, 0, 0), (0, 255, 0), (0, 0, 255), (100, 100, 100)]},
)
*testing.product_dict([
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': None,
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': None,
'label_names': None},
{
'n_bbox': 3, 'label': (0, 1, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 0, 'label': (), 'score': (),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2'), 'no_img': True},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2'),
'instance_colors': [
(255, 0, 0), (0, 255, 0), (0, 0, 255), (100, 100, 100)]},
], [{'sort_by_score': False}, {'sort_by_score': True}]))
@unittest.skipUnless(_available, 'Matplotlib is not installed')
class TestVisBbox(unittest.TestCase):

Expand All @@ -69,12 +70,13 @@ def test_vis_bbox(self):
ax = vis_bbox(
self.img, self.bbox, self.label, self.score,
label_names=self.label_names,
instance_colors=self.instance_colors)
instance_colors=self.instance_colors,
sort_by_score=self.sort_by_score)

self.assertIsInstance(ax, matplotlib.axes.Axes)


@testing.parameterize(
@testing.parameterize(*testing.product_dict([
{
'n_bbox': 3, 'label': (0, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
Expand All @@ -95,7 +97,7 @@ def test_vis_bbox(self):
{
'n_bbox': 3, 'label': (-1, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
)
], [{'sort_by_score': False}, {'sort_by_score': True}]))
@unittest.skipUnless(_available, 'Matplotlib is not installed')
class TestVisBboxInvalidInputs(unittest.TestCase):

Expand All @@ -114,7 +116,8 @@ def test_vis_bbox_invalid_inputs(self):
vis_bbox(
self.img, self.bbox, self.label, self.score,
label_names=self.label_names,
instance_colors=self.instance_colors)
instance_colors=self.instance_colors,
sort_by_score=self.sort_by_score)


testing.run_module(__name__, __file__)