Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] metainfo of dataset can be a generic dict-like Mapping #1378

Merged
merged 3 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 13 additions & 10 deletions mmengine/dataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import gc
import logging
import pickle
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

from mmengine.config import Config
from mmengine.fileio import join_path, list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
Expand Down Expand Up @@ -155,8 +157,8 @@ class BaseDataset(Dataset):

Args:
ann_file (str, optional): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
metainfo (Mapping or Config, optional): Meta information for
dataset, such as class information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (dict): Prefix for training data. Defaults to
Expand Down Expand Up @@ -213,7 +215,7 @@ class BaseDataset(Dataset):

def __init__(self,
ann_file: Optional[str] = '',
metainfo: Optional[dict] = None,
metainfo: Union[Mapping, Config, None] = None,
data_root: Optional[str] = '',
data_prefix: dict = dict(img_path=''),
filter_cfg: Optional[dict] = None,
Expand Down Expand Up @@ -472,13 +474,14 @@ def load_data_list(self) -> List[dict]:
return data_list

@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
def _load_metainfo(cls,
metainfo: Union[Mapping, Config, None] = None) -> dict:
"""Collect meta information from the dictionary of meta.

Args:
metainfo (dict): Meta information dict. If ``metainfo``
contains existed filename, it will be parsed by
``list_from_file``.
metainfo (Mapping or Config, optional): Meta information dict.
If ``metainfo`` contains existed filename, it will be
parsed by ``list_from_file``.

Returns:
dict: Parsed meta information.
Expand All @@ -487,9 +490,9 @@ def _load_metainfo(cls, metainfo: dict = None) -> dict:
cls_metainfo = copy.deepcopy(cls.METAINFO)
if metainfo is None:
return cls_metainfo
if not isinstance(metainfo, dict):
raise TypeError(
f'metainfo should be a dict, but got {type(metainfo)}')
if not isinstance(metainfo, (Mapping, Config)):
raise TypeError('metainfo should be a Mapping or Config, '
f'but got {type(metainfo)}')

for k, v in metainfo.items():
if isinstance(v, str):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_dataset/test_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch

from mmengine.config import Config, ConfigDict
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS
Expand Down Expand Up @@ -202,6 +203,39 @@ def test_meta(self):
task_name='new_task',
classes=('dog', ),
empty_list=[])

# test dataset.metainfo with passing metainfo as Config into
# self.base_dataset
metainfo = Config(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])

# test dataset.metainfo with passing metainfo as ConfigDict (Mapping)
# into self.base_dataset
metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])

# reset `base_dataset.METAINFO`, the `dataset.metainfo` should not
# change
BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
Expand Down