Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jan 14, 2022
1 parent 4c73017 commit 46809a7
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmcls/datasets/builder.py
Expand Up @@ -37,9 +37,9 @@ def build_dataset(cfg, default_args=None):
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif cfg['type'] == 'KFoldDataset':
cp_cfg = copy.deepcopy(cfg)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
if cp_cfg.get('test_mode', None) is None:
cp_cfg['test_mode'] = (default_args or {}).get('test_mode', False)
cp_cfg['test_mode'] = (default_args or {}).pop('test_mode', False)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
cp_cfg.pop('type')
dataset = KFoldDataset(**cp_cfg)
else:
Expand Down
152 changes: 151 additions & 1 deletion tests/test_data/test_builder.py
@@ -1,9 +1,14 @@
import os.path as osp
from copy import deepcopy
from unittest.mock import patch

import torch
from mmcv.utils import digit_version

from mmcls.datasets import build_dataloader
from mmcls.datasets import ImageNet, build_dataloader, build_dataset
from mmcls.datasets.dataset_wrappers import (ClassBalancedDataset,
ConcatDataset, KFoldDataset,
RepeatDataset)


class TestDataloaderBuilder():
Expand Down Expand Up @@ -119,3 +124,148 @@ def test_distributed(self, _):
expect = torch.tensor(
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6][1::2])
assert all(torch.cat(list(iter(dataloader))) == expect)


class TestDatasetBuilder():

@classmethod
def setup_class(cls):
data_prefix = osp.join(osp.dirname(__file__), '../data/dataset')
cls.dataset_cfg = dict(
type='ImageNet',
data_prefix=data_prefix,
ann_file=osp.join(data_prefix, 'ann.txt'),
pipeline=[],
test_mode=False,
)

def test_normal_dataset(self):
# Test build
dataset = build_dataset(self.dataset_cfg)
assert isinstance(dataset, ImageNet)
assert dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(self.dataset_cfg, {'test_mode': True})
assert dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(cp_cfg, {'test_mode': True})
assert dataset.test_mode

def test_concat_dataset(self):
# Test build
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg])
assert isinstance(dataset, ConcatDataset)
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg],
{'test_mode': True})
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset([cp_cfg, cp_cfg], {'test_mode': True})
assert dataset.datasets[0].test_mode

def test_repeat_dataset(self):
# Test build
dataset = build_dataset(
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3))
assert isinstance(dataset, RepeatDataset)
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3),
{'test_mode': True})
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(type='RepeatDataset', dataset=cp_cfg, times=3),
{'test_mode': True})
assert dataset.dataset.test_mode

def test_class_balance_dataset(self):
# Test build
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=self.dataset_cfg,
oversample_thr=1.,
))
assert isinstance(dataset, ClassBalancedDataset)
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=self.dataset_cfg,
oversample_thr=1.,
), {'test_mode': True})
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=cp_cfg,
oversample_thr=1.,
), {'test_mode': True})
assert dataset.dataset.test_mode

def test_kfold_dataset(self):
# Test build
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=self.dataset_cfg,
fold=0,
num_splits=5,
test_mode=False,
))
assert isinstance(dataset, KFoldDataset)
assert not dataset.test_mode
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=self.dataset_cfg,
fold=0,
num_splits=5,
test_mode=False,
),
default_args={
'test_mode': True,
'classes': [1, 2, 3]
})
assert not dataset.test_mode
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
assert dataset.dataset.CLASSES == [1, 2, 3]

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=self.dataset_cfg,
fold=0,
num_splits=5,
),
default_args={
'test_mode': True,
'classes': [1, 2, 3]
})
# The test_mode in default_args will be passed to KFoldDataset
assert dataset.test_mode
assert not dataset.dataset.test_mode
# Other default_args will be passed to child dataset.
assert dataset.dataset.CLASSES == [1, 2, 3]

0 comments on commit 46809a7

Please sign in to comment.