Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datumaro] Label remapping transform #1233

Merged
merged 4 commits into from
Mar 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions datumaro/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ def sources(self):
return self._sources

def _save_branch_project(self, extractor, save_dir=None):
extractor = Dataset.from_extractors(extractor) # apply lazy transforms

# NOTE: probably this function should be in the ViewModel layer
save_dir = osp.abspath(save_dir)
if save_dir:
Expand Down
120 changes: 118 additions & 2 deletions datumaro/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
#
# SPDX-License-Identifier: MIT

from enum import Enum
import logging as log
import os.path as osp
import random

import pycocotools.mask as mask_utils

from datumaro.components.extractor import (Transform, AnnotationType,
RleMask, Polygon, Bbox)
RleMask, Polygon, Bbox,
LabelCategories, MaskCategories, PointsCategories
)
from datumaro.components.cli_plugin import CliPlugin
import datumaro.util.mask_tools as mask_tools
from datumaro.util.annotation_tools import find_group_leader, find_instances
Expand Down Expand Up @@ -46,7 +49,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
segments.append(s.points)
elif s.type == AnnotationType.mask:
if isinstance(s, RleMask):
rle = s._rle
rle = s.rle
else:
rle = mask_tools.mask_to_rle(s.image)
segments.append(rle)
Expand Down Expand Up @@ -365,3 +368,116 @@ def transform_item(self, item):
if item.has_image and item.image.filename:
name = osp.splitext(item.image.filename)[0]
return self.wrap_item(item, id=name)

class RemapLabels(Transform, CliPlugin):
DefaultAction = Enum('DefaultAction', ['keep', 'delete'])

@staticmethod
def _split_arg(s):
parts = s.split(':')
if len(parts) != 2:
import argparse
raise argparse.ArgumentTypeError()
return (parts[0], parts[1])

@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-l', '--label', action='append',
type=cls._split_arg, dest='mapping',
help="Label in the form of: '<src>:<dst>' (repeatable)")
parser.add_argument('--default',
choices=[a.name for a in cls.DefaultAction],
default=cls.DefaultAction.keep.name,
help="Action for unspecified labels")
return parser

def __init__(self, extractor, mapping, default=None):
super().__init__(extractor)

assert isinstance(default, (str, self.DefaultAction))
if isinstance(default, str):
default = self.DefaultAction[default]

assert isinstance(mapping, (dict, list))
if isinstance(mapping, list):
mapping = dict(mapping)

self._categories = {}

src_label_cat = self._extractor.categories().get(AnnotationType.label)
if src_label_cat is not None:
self._make_label_id_map(src_label_cat, mapping, default)

src_mask_cat = self._extractor.categories().get(AnnotationType.mask)
if src_mask_cat is not None:
assert src_label_cat is not None
dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
dst_mask_cat.colormap = {
id: src_mask_cat.colormap[id]
for id, _ in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
}
self._categories[AnnotationType.mask] = dst_mask_cat

src_points_cat = self._extractor.categories().get(AnnotationType.points)
if src_points_cat is not None:
assert src_label_cat is not None
dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
dst_points_cat.items = {
id: src_points_cat.items[id]
for id, item in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
}
self._categories[AnnotationType.points] = dst_points_cat

def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
id_mapping = {}
for src_index, src_label in enumerate(src_label_cat.items):
dst_label = label_mapping.get(src_label.name)
if not dst_label and default_action == self.DefaultAction.keep:
dst_label = src_label.name # keep unspecified as is
if not dst_label:
continue

dst_index = dst_label_cat.find(dst_label)[0]
if dst_index is None:
dst_label_cat.add(dst_label,
src_label.parent, src_label.attributes)
dst_index = dst_label_cat.find(dst_label)[0]
id_mapping[src_index] = dst_index

if log.getLogger().isEnabledFor(log.DEBUG):
log.debug("Label mapping:")
for src_id, src_label in enumerate(src_label_cat.items):
if id_mapping.get(src_id):
log.debug("#%s '%s' -> #%s '%s'",
src_id, src_label.name, id_mapping[src_id],
dst_label_cat.items[id_mapping[src_id]].name
)
else:
log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)

self._map_id = lambda src_id: id_mapping.get(src_id, None)
self._categories[AnnotationType.label] = dst_label_cat

def categories(self):
return self._categories

def transform_item(self, item):
# TODO: provide non-inplace version
annotations = []
for ann in item.annotations:
if ann.type in { AnnotationType.label, AnnotationType.mask,
AnnotationType.points, AnnotationType.polygon,
AnnotationType.polyline, AnnotationType.bbox
} and ann.label is not None:
conv_label = self._map_id(ann.label)
if conv_label is not None:
ann._label = conv_label
annotations.append(ann)
else:
annotations.append(ann)
item._annotations = annotations
return item
25 changes: 12 additions & 13 deletions datumaro/datumaro/plugins/voc_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,13 @@ def _write_xml_bbox(bbox, parent_elem):
class _Converter:
def __init__(self, extractor, save_dir,
tasks=None, apply_colormap=True, save_images=False, label_map=None):
assert tasks is None or isinstance(tasks, (VocTask, list))
assert tasks is None or isinstance(tasks, (VocTask, list, set))
if tasks is None:
tasks = list(VocTask)
tasks = set(VocTask)
elif isinstance(tasks, VocTask):
tasks = [tasks]
tasks = {tasks}
else:
tasks = [t if t in VocTask else VocTask[t] for t in tasks]

tasks = set(t if t in VocTask else VocTask[t] for t in tasks)
self._tasks = tasks

self._extractor = extractor
Expand Down Expand Up @@ -259,10 +258,10 @@ def save_subsets(self):
if len(actions_elem) != 0:
obj_elem.append(actions_elem)

if set(self._tasks) & set([None,
if self._tasks & {None,
VocTask.detection,
VocTask.person_layout,
VocTask.action_classification]):
VocTask.action_classification}:
with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f:
f.write(ET.tostring(root_elem,
encoding='unicode', pretty_print=True))
Expand Down Expand Up @@ -302,19 +301,19 @@ def save_subsets(self):
action_list[item.id] = None
segm_list[item.id] = None

if set(self._tasks) & set([None,
if self._tasks & {None,
VocTask.classification,
VocTask.detection,
VocTask.action_classification,
VocTask.person_layout]):
VocTask.person_layout}:
self.save_clsdet_lists(subset_name, clsdet_list)
if set(self._tasks) & set([None, VocTask.classification]):
if self._tasks & {None, VocTask.classification}:
self.save_class_lists(subset_name, class_lists)
if set(self._tasks) & set([None, VocTask.action_classification]):
if self._tasks & {None, VocTask.action_classification}:
self.save_action_lists(subset_name, action_list)
if set(self._tasks) & set([None, VocTask.person_layout]):
if self._tasks & {None, VocTask.person_layout}:
self.save_layout_lists(subset_name, layout_list)
if set(self._tasks) & set([None, VocTask.segmentation]):
if self._tasks & {None, VocTask.segmentation}:
self.save_segm_lists(subset_name, segm_list)

def save_action_lists(self, subset_name, action_list):
Expand Down
98 changes: 96 additions & 2 deletions datumaro/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from unittest import TestCase

from datumaro.components.extractor import (Extractor, DatasetItem,
Mask, Polygon, PolyLine, Points, Bbox
Mask, Polygon, PolyLine, Points, Bbox, Label,
LabelCategories, MaskCategories, AnnotationType
)
from datumaro.util.test_utils import compare_datasets
import datumaro.util.mask_tools as mask_tools
import datumaro.plugins.transforms as transforms
from datumaro.util.test_utils import compare_datasets


class TransformsTest(TestCase):
Expand Down Expand Up @@ -361,3 +363,95 @@ def __iter__(self):
('train', -0.5),
('test', 1.5),
])

def test_remap_labels(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
# Should be remapped
Label(1),
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3),

# Should be kept
Polygon([1, 1, 2, 2, 3, 4], label=4),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])

def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label1')
label_cat.add('label2')
label_cat.add('label3')
label_cat.add('label4')

mask_cat = MaskCategories(
colormap=mask_tools.generate_colormap(5))

return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}

class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
Label(1),
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),

Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])

def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label9')
label_cat.add('label4')

mask_cat = MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
})

return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}

actual = transforms.RemapLabels(SrcExtractor(), mapping={
'label1': 'label9',
'label2': 'label0',
'label3': 'label9',
}, default='keep')

compare_datasets(self, DstExtractor(), actual)

def test_remap_labels_delete_unspecified(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([ DatasetItem(id=1, annotations=[ Label(0) ]) ])

def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')

return { AnnotationType.label: label_cat }

class DstExtractor(Extractor):
def __iter__(self):
return iter([ DatasetItem(id=1, annotations=[]) ])

def categories(self):
return { AnnotationType.label: LabelCategories() }

actual = transforms.RemapLabels(SrcExtractor(),
mapping={}, default='delete')

compare_datasets(self, DstExtractor(), actual)