Skip to content

Commit

Permalink
[Enhance] metainfo of dataset can be a generic dict-like Mapping (#1378)
Browse files: browse the repository at this point in the history
  • Loading branch information
hiyyg committed Oct 8, 2023
1 parent 9cbe066 commit eb5834f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 10 deletions.
23 changes: 13 additions & 10 deletions mmengine/dataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import gc
import logging
import pickle
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

from mmengine.config import Config
from mmengine.fileio import join_path, list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
Expand Down Expand Up @@ -155,8 +157,8 @@ class BaseDataset(Dataset):
Args:
ann_file (str, optional): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
metainfo (Mapping or Config, optional): Meta information for
dataset, such as class information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (dict): Prefix for training data. Defaults to
Expand Down Expand Up @@ -213,7 +215,7 @@ class BaseDataset(Dataset):

def __init__(self,
ann_file: Optional[str] = '',
metainfo: Optional[dict] = None,
metainfo: Union[Mapping, Config, None] = None,
data_root: Optional[str] = '',
data_prefix: dict = dict(img_path=''),
filter_cfg: Optional[dict] = None,
Expand Down Expand Up @@ -472,13 +474,14 @@ def load_data_list(self) -> List[dict]:
return data_list

@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
def _load_metainfo(cls,
metainfo: Union[Mapping, Config, None] = None) -> dict:
"""Collect meta information from the dictionary of meta.
Args:
metainfo (dict): Meta information dict. If ``metainfo``
contains existed filename, it will be parsed by
``list_from_file``.
metainfo (Mapping or Config, optional): Meta information dict.
If ``metainfo`` contains existed filename, it will be
parsed by ``list_from_file``.
Returns:
dict: Parsed meta information.
Expand All @@ -487,9 +490,9 @@ def _load_metainfo(cls, metainfo: dict = None) -> dict:
cls_metainfo = copy.deepcopy(cls.METAINFO)
if metainfo is None:
return cls_metainfo
if not isinstance(metainfo, dict):
raise TypeError(
f'metainfo should be a dict, but got {type(metainfo)}')
if not isinstance(metainfo, (Mapping, Config)):
raise TypeError('metainfo should be a Mapping or Config, '
f'but got {type(metainfo)}')

for k, v in metainfo.items():
if isinstance(v, str):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_dataset/test_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch

from mmengine.config import Config, ConfigDict
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS
Expand Down Expand Up @@ -202,6 +203,39 @@ def test_meta(self):
task_name='new_task',
classes=('dog', ),
empty_list=[])

# test dataset.metainfo with passing metainfo as Config into
# self.base_dataset
metainfo = Config(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])

# test dataset.metainfo with passing metainfo as ConfigDict (Mapping)
# into self.base_dataset
metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])

# reset `base_dataset.METAINFO`, the `dataset.metainfo` should not
# change
BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
Expand Down

0 comments on commit eb5834f

Please sign in to comment.