
Commit

Fix
Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
vinnamkim committed Jul 17, 2023
1 parent bcaca42 commit 448add1
Showing 10 changed files with 126 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/datumaro/components/dataset.py
@@ -587,6 +587,7 @@ def export(

assert "ctx" not in kwargs
exporter_kwargs = copy(kwargs)
exporter_kwargs["stream"] = self._stream
exporter_kwargs["ctx"] = ExportContext(
progress_reporter=progress_reporter, error_policy=error_policy
)
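With this change, Dataset.export forwards the dataset's own streaming flag to the exporter, so a dataset that was loaded lazily is also written out without materializing every item first. A minimal usage sketch, assuming StreamDataset is importable from datumaro.components.dataset; the paths are placeholders:

from datumaro.components.dataset import StreamDataset

# The export call takes no explicit stream argument; the flag is read from the
# dataset itself (self._stream) and handed to the exporter automatically.
dataset = StreamDataset.import_from("path/to/source", "datumaro")  # placeholder path
dataset.export("path/to/output", "datumaro", save_media=True)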
12 changes: 12 additions & 0 deletions src/datumaro/components/exporter.py
@@ -182,6 +182,7 @@ def __init__(
default_image_ext: Optional[str] = None,
save_dataset_meta: bool = False,
save_hashkey_meta: bool = False,
stream: bool = False,
ctx: Optional[ExportContext] = None,
):
default_image_ext = default_image_ext or self.DEFAULT_IMAGE_EXT
@@ -222,6 +223,12 @@ def __init__(
else:
self._patch = None

if stream and not self.can_stream:
raise DatasetExportError(
f"{self.__class__.__name__} cannot export a dataset in a stream manner"
)
self._stream = stream

self._ctx: ExportContext = ctx or NullExportContext()

def _find_image_ext(self, item: Union[DatasetItem, Image]):
@@ -299,6 +306,11 @@ def _check_hash_key_existence(self, item):
self._save_hashkey_meta = True
return

@property
def can_stream(self) -> bool:
"""Flag to indicate whether the exporter can export the dataset in a stream manner or not."""
return False


# TODO: Currently, ExportContextComponent is introduced only for Datumaro and DatumaroBinary format
# for multi-processing. We need to propagate this to everywhere in Datumaro 1.2.0
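The base Exporter now accepts a stream flag but rejects it unless the concrete exporter advertises support through the new can_stream property. A sketch of how a plugin could opt in; the class below is hypothetical and not part of this commit:

from datumaro.components.exporter import Exporter

class MyStreamingExporter(Exporter):  # hypothetical plugin
    DEFAULT_IMAGE_EXT = ".jpg"

    @property
    def can_stream(self) -> bool:
        # Without this override, the base __init__ raises DatasetExportError
        # whenever the exporter is constructed with stream=True.
        return True

    def _apply_impl(self, *args, **kwargs):
        # Write each item as it is pulled from the extractor instead of
        # collecting the whole dataset first.
        for item in self._extractor:
            ...  # write the item to disk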
9 changes: 7 additions & 2 deletions src/datumaro/plugins/data_formats/arrow/exporter.py
@@ -58,9 +58,13 @@ def __init__(
num_shards: int = 1,
max_shard_size: Optional[int] = None,
):
super().__init__(context, "", export_context)
super().__init__(
context=context,
subset=subset,
ann_file="",
export_context=export_context,
)
self._schema = deepcopy(DatumaroArrow.SCHEMA)
self._subset = subset
self._writers = []
self._fnames = []
self._max_chunk_size = max_chunk_size
@@ -370,6 +374,7 @@ def __init__(
num_shards: int = 1,
max_shard_size: Optional[int] = None,
max_chunk_size: int = 1000,
**kwargs,
):
super().__init__(
extractor=extractor,
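ArrowExporter.__init__ (and DatumaroBinaryExporter below) now take **kwargs, so options forwarded by Dataset.export, such as the new stream flag, are accepted instead of raising a TypeError in exporters that do not handle them. A minimal illustration of the pattern; both classes are hypothetical:

class Base:  # hypothetical stand-in for Exporter
    def __init__(self, stream: bool = False, **kwargs):
        self.stream = stream

class Plugin(Base):  # hypothetical stand-in for a format-specific exporter
    def __init__(self, num_shards: int = 1, **kwargs):
        # **kwargs absorbs options this plugin does not declare (e.g. stream=...).
        super().__init__(**kwargs)
        self.num_shards = num_shards

plugin = Plugin(num_shards=2, stream=False)  # accepted without a TypeError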
86 changes: 72 additions & 14 deletions src/datumaro/plugins/data_formats/datumaro/exporter.py
@@ -4,15 +4,17 @@

# pylint: disable=no-self-use

import json
import os
import os.path as osp
import shutil
from contextlib import contextmanager
from multiprocessing.pool import Pool
from typing import Optional
from typing import Dict, Optional

import numpy as np
import pycocotools.mask as mask_utils
from json_stream.writer import streamable_dict, streamable_list

from datumaro.components.annotation import (
Annotation,
@@ -43,8 +45,15 @@


class _SubsetWriter:
def __init__(self, context: Exporter, ann_file: str, export_context: ExportContextComponent):
def __init__(
self,
context: Exporter,
subset: str,
ann_file: str,
export_context: ExportContextComponent,
):
self._context = context
self._subset = subset

self._data = {
"dm_format_version": DATUMARO_FORMAT_VERSION,
@@ -122,7 +131,10 @@ def context_save_media(
else:
raise NotImplementedError

def add_item(self, item: DatasetItem, *args, **kwargs):
def add_item(self, item: DatasetItem, *args, **kwargs) -> None:
self.items.append(self._gen_item_desc(item))

def _gen_item_desc(self, item: DatasetItem, *args, **kwargs) -> Dict:
annotations = []
item_desc = {
"id": item.id,
@@ -155,8 +167,6 @@ def add_item(self, item: DatasetItem, *args, **kwargs):
elif isinstance(item.media, MediaElement):
item_desc["media"] = {"path": getattr(item.media, "path", None)}

self.items.append(item_desc)

for ann in item.annotations:
if isinstance(ann, Label):
converted_ann = self._convert_label_object(ann)
@@ -182,6 +192,8 @@ def add_item(self, item: DatasetItem, *args, **kwargs):
raise NotImplementedError()
annotations.append(converted_ann)

return item_desc

def add_infos(self, infos):
self._data["infos"].update(infos)

@@ -367,6 +379,39 @@ def _convert_points_categories(self, obj):
return converted


class _StreamSubsetWriter(_SubsetWriter):
def __init__(
self,
context: Exporter,
subset: str,
ann_file: str,
export_context: ExportContextComponent,
):
super().__init__(context, subset, ann_file, export_context)

def write(self, *args, **kwargs):
@streamable_list
def _item_list():
subset = self._context._extractor.get_subset(self._subset)
for item in subset:
yield self._gen_item_desc(item)

@streamable_dict
def _data():
yield "dm_format_version", self._data["dm_format_version"]
yield "media_type", self._data["media_type"]
yield "infos", self._data["infos"]
yield "categories", self._data["categories"]
yield "items", _item_list()

with open(self.ann_file, "w", encoding="utf-8") as fp:
json.dump(_data(), fp)

def is_empty(self):
# TODO: is_empty is forced to return False here; fix this when refactoring `_SubsetWriter`.
return False


class DatumaroExporter(Exporter):
DEFAULT_IMAGE_EXT = DatumaroPath.IMAGE_EXT
PATH_CLS = DatumaroPath
@@ -387,10 +432,20 @@ def create_writer(
default_image_ext=self._default_image_ext,
)

return _SubsetWriter(
context=self,
ann_file=osp.join(self._annotations_dir, subset + self.PATH_CLS.ANNOTATION_EXT),
export_context=export_context,
return (
_SubsetWriter(
context=self,
subset=subset,
ann_file=osp.join(self._annotations_dir, subset + self.PATH_CLS.ANNOTATION_EXT),
export_context=export_context,
)
if not self._stream
else _StreamSubsetWriter(
context=self,
subset=subset,
ann_file=osp.join(self._annotations_dir, subset + self.PATH_CLS.ANNOTATION_EXT),
export_context=export_context,
)
)

def _apply_impl(self, pool: Optional[Pool] = None, *args, **kwargs):
@@ -419,11 +474,10 @@ def _apply_impl(self, pool: Optional[Pool] = None, *args, **kwargs):
writer.add_infos(self._extractor.infos())
writer.add_categories(self._extractor.categories())

for item in self._extractor:
subset = item.subset or DEFAULT_SUBSET_NAME
writers[subset].add_item(item, pool)

self._check_hash_key_existence(item)
if not self._stream:
for item in self._extractor:
subset = item.subset or DEFAULT_SUBSET_NAME
writers[subset].add_item(item, pool)

for subset, writer in writers.items():
if self._patch and subset in self._patch.updated_subsets and writer.is_empty():
@@ -469,3 +523,7 @@ def patch(cls, dataset, patch, save_dir, **kwargs):
related_images_path = osp.join(save_dir, cls.PATH_CLS.IMAGES_DIR, item.subset, item.id)
if osp.isdir(related_images_path):
shutil.rmtree(related_images_path)

@property
def can_stream(self) -> bool:
return True
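_StreamSubsetWriter builds on the json-stream writer helpers: generator functions decorated with streamable_dict and streamable_list are consumed lazily by json.dump, so the item list is never held in memory as a whole. A standalone sketch of the same pattern; the file name and payload are made up for illustration:

import json

from json_stream.writer import streamable_dict, streamable_list

@streamable_list
def _items():
    # Items are produced one at a time while json.dump writes them out.
    for idx in range(3):
        yield {"id": str(idx), "annotations": []}

@streamable_dict
def _document():
    yield "dm_format_version", "1.0"
    yield "items", _items()

with open("annotations.json", "w", encoding="utf-8") as fp:  # made-up file name
    json.dump(_document(), fp)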
src/datumaro/plugins/data_formats/datumaro_binary/exporter.py
@@ -33,13 +33,14 @@ class _SubsetWriter(__SubsetWriter):
def __init__(
self,
context: Exporter,
subset: str,
ann_file: str,
export_context: ExportContextComponent,
secret_key_file: str,
no_media_encryption: bool = False,
max_blob_size: int = DatumaroBinaryPath.MAX_BLOB_SIZE,
):
super().__init__(context, ann_file, export_context)
super().__init__(context, subset, ann_file, export_context)
self._crypter = self.export_context.crypter
self.secret_key_file = secret_key_file

@@ -248,6 +249,7 @@ def __init__(
encryption: bool = False,
num_workers: int = 0,
max_blob_size: int = DatumaroBinaryPath.MAX_BLOB_SIZE,
**kwargs,
):
"""
Parameters
@@ -308,6 +310,7 @@ def create_writer(self, subset: str, images_dir: str, pcd_dir: str) -> _SubsetWr

return _SubsetWriter(
context=self,
subset=subset,
ann_file=osp.join(self._annotations_dir, subset + self.PATH_CLS.ANNOTATION_EXT),
export_context=export_context,
secret_key_file=osp.join(self._save_dir, self.PATH_CLS.SECRET_KEY_FILE),
4 changes: 4 additions & 0 deletions src/datumaro/plugins/data_formats/voc/exporter.py
@@ -795,6 +795,10 @@ def patch(cls, dataset, patch, save_dir, **kwargs):
if osp.isfile(path):
os.unlink(path)

@property
def can_stream(self) -> bool:
return True


class VocClassificationExporter(VocExporter):
def __init__(self, *args, **kwargs):
4 changes: 4 additions & 0 deletions src/datumaro/plugins/data_formats/yolo/exporter.py
@@ -189,6 +189,10 @@ def patch(cls, dataset, patch, save_dir, **kwargs):
if osp.isfile(ann_path):
os.remove(ann_path)

@property
def can_stream(self) -> bool:
return True


class YoloUltralyticsExporter(YoloExporter):
allowed_subset_names = {"train", "val", "test"}
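VOC and YOLO annotations can be written item by item, so both exporters simply declare can_stream as True. A usage sketch mirroring the updated tests; the paths are placeholders:

from datumaro.components.dataset import StreamDataset
from datumaro.plugins.data_formats.yolo.exporter import YoloExporter

source = StreamDataset.import_from("path/to/yolo_dataset", "yolo")  # placeholder path
YoloExporter.convert(source, save_dir="path/to/output", save_media=True, stream=True)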
8 changes: 6 additions & 2 deletions tests/unit/data_formats/base.py
@@ -75,11 +75,15 @@ def test_can_export_and_import(

helper_tc = request.getfixturevalue("helper_tc")

stream = True if dataset_cls == StreamDataset else False
exporter.convert(
fxt_expected_dataset, save_dir=test_dir, save_media=True, **fxt_export_kwargs
fxt_expected_dataset,
save_dir=test_dir,
save_media=True,
stream=stream,
**fxt_export_kwargs,
)
dataset = dataset_cls.import_from(test_dir, importer.NAME, **fxt_import_kwargs)
stream = True if dataset_cls == StreamDataset else False
check_is_stream(dataset, stream)

compare_datasets(helper_tc, fxt_expected_dataset, dataset, require_media=True)
2 changes: 1 addition & 1 deletion tests/unit/data_formats/datumaro/test_datumaro_format.py
@@ -132,7 +132,7 @@ def test_can_save_and_load(
self._test_save_and_load(
helper_tc,
fxt_dataset,
partial(self.exporter.convert, save_media=True, **fxt_export_kwargs),
partial(self.exporter.convert, save_media=True, stream=stream, **fxt_export_kwargs),
test_dir,
compare=compare,
require_media=require_media,
26 changes: 15 additions & 11 deletions tests/unit/data_formats/test_yolo_strict_format.py
@@ -71,7 +71,7 @@ def test_can_save_and_load(self, dataset_cls, is_stream, test_dir, helper_tc):
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir, save_media=True)
YoloExporter.convert(source_dataset, test_dir, save_media=True, stream=is_stream)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -95,7 +95,7 @@ def test_can_save_dataset_with_image_info(self, dataset_cls, is_stream, test_dir
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir)
YoloExporter.convert(source_dataset, test_dir, stream=is_stream)

save_image(
osp.join(test_dir, "obj_train_data", "1.jpg"), np.ones((10, 15, 3))
@@ -125,7 +125,7 @@ def test_can_load_dataset_with_exact_image_info(
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir)
YoloExporter.convert(source_dataset, test_dir, stream=is_stream)

parsed_dataset = dataset_cls.import_from(test_dir, "yolo", image_info={"1": (10, 15)})
assert parsed_dataset.is_stream == is_stream
@@ -152,7 +152,7 @@ def test_can_save_dataset_with_cyrillic_and_spaces_in_filename(
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir, save_media=True)
YoloExporter.convert(source_dataset, test_dir, save_media=True, stream=is_stream)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -177,7 +177,7 @@ def test_relative_paths(self, dataset_cls, is_stream, save_media, test_dir, help
categories=[],
)

YoloExporter.convert(source_dataset, test_dir, save_media=save_media)
YoloExporter.convert(source_dataset, test_dir, save_media=save_media, stream=is_stream)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -204,7 +204,7 @@ def test_can_save_and_load_image_with_arbitrary_extension(
categories=[],
)

YoloExporter.convert(dataset, test_dir, save_media=True)
YoloExporter.convert(dataset, test_dir, save_media=True, stream=is_stream)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -231,11 +231,11 @@ def test_inplace_save_writes_only_updated_data(
],
categories=[],
)
dataset.export(test_dir, "yolo", save_media=True)
dataset.export(test_dir, "yolo", save_media=True, stream=is_stream)

dataset.put(DatasetItem(2, subset="train", media=Image.from_numpy(data=np.ones((3, 2, 3)))))
dataset.remove(3, "valid")
dataset.save(save_media=True)
dataset.save(save_media=True, stream=is_stream)

assert {"1.txt", "2.txt", "1.jpg", "2.jpg"} == set(
os.listdir(osp.join(test_dir, "obj_train_data"))
@@ -284,7 +284,9 @@ def test_can_save_and_load_with_meta_file(self, dataset_cls, is_stream, test_dir
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir, save_media=True, save_dataset_meta=True)
YoloExporter.convert(
source_dataset, test_dir, save_media=True, save_dataset_meta=True, stream=is_stream
)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -311,7 +313,7 @@ def test_can_save_and_load_with_custom_subset_name(
categories=["label_" + str(i) for i in range(10)],
)

YoloExporter.convert(source_dataset, test_dir, save_media=True)
YoloExporter.convert(source_dataset, test_dir, save_media=True, stream=is_stream)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

@@ -354,7 +356,9 @@ def test_can_save_and_load_without_path_prefix(
categories=["a", "b"],
)

YoloExporter.convert(source_dataset, test_dir, save_media=True, add_path_prefix=False)
YoloExporter.convert(
source_dataset, test_dir, save_media=True, add_path_prefix=False, stream=is_stream
)
parsed_dataset = dataset_cls.import_from(test_dir, "yolo")
assert parsed_dataset.is_stream == is_stream

