Skip to content

Commit

Permalink
Choose the top priority detect format for all directory depths (#839)
Browse files Browse the repository at this point in the history
* Make detect_dataset_format() return confidence too
* Choose the highest-priority format found across all searched depths

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
  • Loading branch information
vinnamkim committed Mar 13, 2023
1 parent 66580cb commit f1c467a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/835>)
- Fix CIFAR10 and 100 detect function
(<https://github.com/openvinotoolkit/datumaro/pull/836>)
- Choose the highest-priority detected format across all directory depths
(<https://github.com/openvinotoolkit/datumaro/pull/839>)

## 24/02/2023 - Release v1.0.0
### Added
Expand Down
28 changes: 19 additions & 9 deletions datumaro/components/environment.py
@@ -1,4 +1,4 @@
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

Expand All @@ -8,10 +8,15 @@
import os.path as osp
from functools import partial
from inspect import isclass
from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional, Type, TypeVar
from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Type, TypeVar

from datumaro.components.cli_plugin import CliPlugin, plugin_types
from datumaro.components.format_detection import RejectionReason, detect_dataset_format
from datumaro.components.format_detection import (
DetectedFormat,
FormatDetectionConfidence,
RejectionReason,
detect_dataset_format,
)
from datumaro.util.os_util import import_foreign_module, split_path

T = TypeVar("T")
Expand Down Expand Up @@ -251,7 +256,8 @@ def detect_dataset(
rejection_callback: Optional[Callable[[str, RejectionReason, str], None]] = None,
) -> List[str]:
ignore_dirs = {"__MSOSX", "__MACOSX"}
matched_formats = set()
all_matched_formats: Set[DetectedFormat] = set()

for _ in range(depth + 1):
detected_formats = detect_dataset_format(
(
Expand All @@ -262,17 +268,21 @@ def detect_dataset(
rejection_callback=rejection_callback,
)

if detected_formats and len(detected_formats) == 1:
return detected_formats
elif detected_formats:
matched_formats |= set(detected_formats)
if detected_formats:
all_matched_formats |= set(detected_formats)

paths = glob.glob(osp.join(path, "*"))
path = "" if len(paths) != 1 else paths[0]
if not osp.isdir(path) or osp.basename(path) in ignore_dirs:
break

return list(matched_formats)
max_conf = (
max(all_matched_formats).confidence
if len(all_matched_formats) > 0
else FormatDetectionConfidence.NONE
)

return [str(format) for format in all_matched_formats if format.confidence == max_conf]

def __reduce__(self):
return (self.__class__, ())
Expand Down
35 changes: 26 additions & 9 deletions datumaro/components/format_detection.py
@@ -1,4 +1,4 @@
# Copyright (C) 2021-2022 Intel Corporation
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

Expand All @@ -7,6 +7,7 @@
import glob
import logging as log
import os.path as osp
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from io import BufferedReader
from typing import (
Expand All @@ -33,6 +34,7 @@ class FormatDetectionConfidence(IntEnum):
belonging to the detector's format.
"""

NONE = 1
LOW = 10
"""
The dataset seems to belong to the format, but the format is too loosely
Expand All @@ -49,6 +51,18 @@ class FormatDetectionConfidence(IntEnum):
# has explicit identification via magic numbers/files.


@dataclass(order=True, frozen=True)
class DetectedFormat:
    """A dataset format name paired with the confidence of its detection.

    The ordering operators (``<``, ``<=``, ``>``, ``>=``) are generated by
    ``dataclass`` and compare ``(confidence, name)`` tuples, so ``max()``
    over a collection yields an entry with the highest confidence.

    Equality is deliberately overridden to compare the format ``name`` only.

    NOTE: the ``dataclass``-generated ``__hash__`` still covers both fields,
    so two instances that compare equal (same name, different confidence)
    generally hash differently and are kept as separate members of a ``set``.
    """

    # Confidence level reported by the detector; primary key for ordering.
    confidence: FormatDetectionConfidence = field(compare=True)
    # Name of the detected format; the only field considered by ``==``.
    name: str

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is a ``DetectedFormat`` with the same name.

        Returns ``NotImplemented`` for operands of any other type, so that
        comparing with e.g. ``None`` or a ``str`` evaluates to ``False``
        instead of raising ``AttributeError``.
        """
        if not isinstance(other, DetectedFormat):
            return NotImplemented
        return self.name == other.name

    def __str__(self) -> str:
        """Return the bare format name."""
        return self.name


# All confidence levels should be positive for a couple of reasons:
# * It makes it possible to use 0 or a negative number as a special
# value that is guaranteed to be less than any real value.
Expand Down Expand Up @@ -457,7 +471,7 @@ def detect_dataset_format(
path: str,
*,
rejection_callback: Optional[RejectionCallback] = None,
) -> Sequence[str]:
) -> Sequence[DetectedFormat]:
"""
Determines which format(s) the dataset at the specified path belongs to.
Expand Down Expand Up @@ -495,7 +509,7 @@ def report_insufficient_confidence(
)

max_confidence = 0
matches = []
matches: List[DetectedFormat] = []

for format_name, detector in formats:
log.debug("Checking '%s' format...", format_name)
Expand All @@ -512,22 +526,25 @@ def report_insufficient_confidence(
# keep only matches with the highest confidence
if new_confidence > max_confidence:
for match in matches:
report_insufficient_confidence(match, format_name)
report_insufficient_confidence(match.name, format_name)

matches = [format_name]
matches = [DetectedFormat(new_confidence, format_name)]
max_confidence = new_confidence
elif new_confidence == max_confidence:
matches.append(format_name)
matches.append(DetectedFormat(new_confidence, format_name))
else: # new confidence is less than max
report_insufficient_confidence(format_name, matches[0])
report_insufficient_confidence(format_name, matches[0].name)

# TODO: This should be controlled by our priority logic.
# However, the detect() implementations of some dataset formats are
# currently broken, so this workaround is unavoidable for now.
# We must revisit this after fixing detect().
def _give_more_priority_to_with_subset_dirs(matches):
def _give_more_priority_to_with_subset_dirs(matches: List[DetectedFormat]):
for idx, match in enumerate(matches):
if match + "_with_subset_dirs" in matches:
with_subset_dir_match = DetectedFormat(
match.confidence, match.name + "_with_subset_dirs"
)
if with_subset_dir_match in matches:
matches = matches.pop(idx)
return True
return False
Expand Down
13 changes: 11 additions & 2 deletions tests/unit/test_format_detection.py
@@ -1,3 +1,7 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
from unittest import TestCase

Expand Down Expand Up @@ -265,7 +269,9 @@ def rejection_callback(format, reason, message):
formats, self._dataset_root, rejection_callback=rejection_callback
)

self.assertEqual(set(detected_datasets), {"bbb", "eee"})
detected_dataset_names = [detected_dataset.name for detected_dataset in detected_datasets]

self.assertEqual(set(detected_dataset_names), {"bbb", "eee"})

self.assertEqual(rejected_formats.keys(), {"aaa", "ccc", "ddd", "fff"})

Expand All @@ -285,4 +291,7 @@ def test_no_callback(self):
]

detected_datasets = detect_dataset_format(formats, self._dataset_root)
self.assertEqual(detected_datasets, ["bbb"])

detected_dataset_names = [detected_dataset.name for detected_dataset in detected_datasets]

self.assertEqual(detected_dataset_names, ["bbb"])

0 comments on commit f1c467a

Please sign in to comment.