From f1c467a967e1b88f545817007f859e0015a04b44 Mon Sep 17 00:00:00 2001 From: Vinnam Kim Date: Mon, 13 Mar 2023 11:02:59 +0900 Subject: [PATCH] Choose the top priority detect format for all directory depths (#839) * Make detect_dataset_format() return confidence too * Choose the highest prioritized one searched for all depths Signed-off-by: Kim, Vinnam --- CHANGELOG.md | 2 ++ datumaro/components/environment.py | 28 +++++++++++++------- datumaro/components/format_detection.py | 35 ++++++++++++++++++------- tests/unit/test_format_detection.py | 13 +++++++-- 4 files changed, 58 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22a42c525d..1fd9599b6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Fix CIFAR10 and 100 detect function () +- Choose the top priority detect format for all directory depths + () ## 24/02/2023 - Release v1.0.0 ### Added diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py index 303f4d1272..1dc51af414 100644 --- a/datumaro/components/environment.py +++ b/datumaro/components/environment.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2022 Intel Corporation +# Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT @@ -8,10 +8,15 @@ import os.path as osp from functools import partial from inspect import isclass -from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional, Type, TypeVar +from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Type, TypeVar from datumaro.components.cli_plugin import CliPlugin, plugin_types -from datumaro.components.format_detection import RejectionReason, detect_dataset_format +from datumaro.components.format_detection import ( + DetectedFormat, + FormatDetectionConfidence, + RejectionReason, + detect_dataset_format, +) from datumaro.util.os_util import import_foreign_module, split_path T = TypeVar("T") @@ -251,7 +256,8 @@ def detect_dataset( rejection_callback: Optional[Callable[[str, RejectionReason, str], None]] = None, ) -> List[str]: ignore_dirs = {"__MSOSX", "__MACOSX"} - matched_formats = set() + all_matched_formats: Set[DetectedFormat] = set() + for _ in range(depth + 1): detected_formats = detect_dataset_format( ( @@ -262,17 +268,21 @@ def detect_dataset( rejection_callback=rejection_callback, ) - if detected_formats and len(detected_formats) == 1: - return detected_formats - elif detected_formats: - matched_formats |= set(detected_formats) + if detected_formats: + all_matched_formats |= set(detected_formats) paths = glob.glob(osp.join(path, "*")) path = "" if len(paths) != 1 else paths[0] if not osp.isdir(path) or osp.basename(path) in ignore_dirs: break - return list(matched_formats) + max_conf = ( + max(all_matched_formats).confidence + if len(all_matched_formats) > 0 + else FormatDetectionConfidence.NONE + ) + + return [str(format) for format in all_matched_formats if format.confidence == max_conf] def __reduce__(self): return (self.__class__, ()) diff --git a/datumaro/components/format_detection.py b/datumaro/components/format_detection.py index 9f92e98728..3aa487b6e3 100644 --- a/datumaro/components/format_detection.py +++ b/datumaro/components/format_detection.py @@ -1,4 +1,4 @@ -# Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2023 Intel Corporation # # SPDX-License-Identifier: MIT @@ -7,6 +7,7 @@ import glob import logging as log import os.path as osp +from dataclasses import dataclass, field from enum import Enum, IntEnum, auto from io import BufferedReader from typing import ( @@ -33,6 +34,7 @@ class FormatDetectionConfidence(IntEnum): belonging to the detector's format. """ + NONE = 1 LOW = 10 """ The dataset seems to belong to the format, but the format is too loosely @@ -49,6 +51,18 @@ class FormatDetectionConfidence(IntEnum): # has explicit identification via magic numbers/files. +@dataclass(order=True, frozen=True) +class DetectedFormat: + confidence: FormatDetectionConfidence = field(compare=True) + name: str + + def __eq__(self, __o: "DetectedFormat") -> bool: + return self.name == __o.name + + def __str__(self) -> str: + return self.name + + # All confidence levels should be positive for a couple of reasons: # * It makes it possible to use 0 or a negative number as a special # value that is guaranteed to be less than any real value. @@ -457,7 +471,7 @@ def detect_dataset_format( path: str, *, rejection_callback: Optional[RejectionCallback] = None, -) -> Sequence[str]: +) -> Sequence[DetectedFormat]: """ Determines which format(s) the dataset at the specified path belongs to. @@ -495,7 +509,7 @@ def report_insufficient_confidence( ) max_confidence = 0 - matches = [] + matches: List[DetectedFormat] = [] for format_name, detector in formats: log.debug("Checking '%s' format...", format_name) @@ -512,22 +526,25 @@ def report_insufficient_confidence( # keep only matches with the highest confidence if new_confidence > max_confidence: for match in matches: - report_insufficient_confidence(match, format_name) + report_insufficient_confidence(match.name, format_name) - matches = [format_name] + matches = [DetectedFormat(new_confidence, format_name)] max_confidence = new_confidence elif new_confidence == max_confidence: - matches.append(format_name) + matches.append(DetectedFormat(new_confidence, format_name)) else: # new confidence is less than max - report_insufficient_confidence(format_name, matches[0]) + report_insufficient_confidence(format_name, matches[0].name) # TODO: This should be controlled by our priority logic. # However, some datasets' detect() are currently broken, # so that it is inevitable to introduce this. # We must revisit this after fixing detect(). - def _give_more_priority_to_with_subset_dirs(matches): + def _give_more_priority_to_with_subset_dirs(matches: List[DetectedFormat]): for idx, match in enumerate(matches): - if match + "_with_subset_dirs" in matches: + with_subset_dir_match = DetectedFormat( + match.confidence, match.name + "_with_subset_dirs" + ) + if with_subset_dir_match in matches: matches = matches.pop(idx) return True return False diff --git a/tests/unit/test_format_detection.py b/tests/unit/test_format_detection.py index 65cf25edb9..ef2e14be43 100644 --- a/tests/unit/test_format_detection.py +++ b/tests/unit/test_format_detection.py @@ -1,3 +1,7 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + import os.path as osp from unittest import TestCase @@ -265,7 +269,9 @@ def rejection_callback(format, reason, message): formats, self._dataset_root, rejection_callback=rejection_callback ) - self.assertEqual(set(detected_datasets), {"bbb", "eee"}) + detected_dataset_names = [detected_dataset.name for detected_dataset in detected_datasets] + + self.assertEqual(set(detected_dataset_names), {"bbb", "eee"}) self.assertEqual(rejected_formats.keys(), {"aaa", "ccc", "ddd", "fff"}) @@ -285,4 +291,7 @@ def test_no_callback(self): ] detected_datasets = detect_dataset_format(formats, self._dataset_root) - self.assertEqual(detected_datasets, ["bbb"]) + + detected_dataset_names = [detected_dataset.name for detected_dataset in detected_datasets] + + self.assertEqual(detected_dataset_names, ["bbb"])