[Feature] Add MultiTaskDataset to support multi-task training. #808

Closed · 1 commit

5 changes: 5 additions & 0 deletions docs/en/api/datasets.rst
@@ -46,6 +46,11 @@ Base classes

.. autoclass:: MultiLabelDataset

Multi-Task Dataset
------------------

.. autoclass:: MultiTaskDataset

Dataset Wrappers
----------------

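For orientation, a hedged sketch of how the newly documented class might appear in a config, following the field conventions of the existing `CustomDataset`. Only `type='MultiTaskDataset'` comes from this PR; the paths, the annotation file name, and its schema are illustrative assumptions:

```python
# Hypothetical dataset config; apart from type='MultiTaskDataset' every field
# follows existing CustomDataset conventions and is illustrative only.
data = dict(
    train=dict(
        type='MultiTaskDataset',
        data_prefix='data/my_dataset/train',    # illustrative path
        ann_file='data/my_dataset/train.json',  # illustrative; the schema is set by the PR
        pipeline=[dict(type='LoadImageFromFile')],  # trimmed; see the pipeline sketch below
    ))
```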
4 changes: 4 additions & 0 deletions docs/en/api/transforms.rst
@@ -169,3 +169,7 @@ ToTensor
Transpose
---------------------
.. autoclass:: Transpose

FormatMultiTaskLabels
---------------------
.. autoclass:: FormatMultiTaskLabels
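
Similarly, a hedged sketch of where `FormatMultiTaskLabels` could sit in a training pipeline. Every other transform below already exists in mmcls; whatever arguments `FormatMultiTaskLabels` accepts are defined by the PR and not assumed here:

```python
# Sketch of a training pipeline using the new transform; all transforms other
# than FormatMultiTaskLabels are pre-existing mmcls transforms.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='Normalize',
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='FormatMultiTaskLabels'),  # format the per-task ground-truth labels
    dict(type='Collect', keys=['img', 'gt_label']),
]
```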
4 changes: 3 additions & 1 deletion mmcls/datasets/__init__.py
@@ -11,6 +11,7 @@
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .multi_task import MultiTaskDataset
from .samplers import DistributedSampler, RepeatAugSampler
from .voc import VOC

@@ -19,5 +20,6 @@
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB',
'CustomDataset', 'MultiTaskDataset'
]
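
With the export added above, the new class is importable from the package root like the existing datasets:

```python
# Import path enabled by the __all__ addition above.
from mmcls.datasets import MultiTaskDataset
```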
24 changes: 19 additions & 5 deletions mmcls/datasets/base_dataset.py
@@ -145,6 +145,24 @@ def evaluate(self,
Returns:
dict: evaluation results
"""
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]

return self.evaluate_single_label(
results=results,
gt_labels=gt_labels,
metric=metric,
metric_options=metric_options,
logger=logger)

@staticmethod
def evaluate_single_label(results,
gt_labels,
metric='accuracy',
metric_options=None,
logger=None):
if metric_options is None:
metric_options = {'topk': (1, 5)}
if isinstance(metric, str):
@@ -154,11 +172,6 @@ def evaluate(self,
allowed_metrics = [
'accuracy', 'precision', 'recall', 'f1_score', 'support'
]
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
@@ -171,6 +184,7 @@ def evaluate(self,
thrs = metric_options.get('thrs')
average_mode = metric_options.get('average_mode', 'macro')

eval_results = {}
if 'accuracy' in metrics:
if thrs is not None:
acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
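Hoisting the metric computation into the static `evaluate_single_label` removes every reference to `self`, which is what lets a multi-task dataset reuse it once per task. A rough sketch of that reuse, assuming per-task dicts of predictions and ground truths (the helper name and both dict arguments are illustrative, not part of the PR):

```python
import numpy as np

from mmcls.datasets import BaseDataset


def evaluate_per_task(task_results, task_gt_labels, metric='accuracy'):
    """Illustrative helper: score each task with the single-label metrics.

    `task_results` maps task name -> list of per-batch score arrays, and
    `task_gt_labels` maps task name -> per-sample label array; both are
    assumptions for this sketch, not the PR's API.
    """
    eval_results = {}
    for task, results in task_results.items():
        task_eval = BaseDataset.evaluate_single_label(
            results=np.vstack(results),
            gt_labels=task_gt_labels[task],
            metric=metric)
        # Prefix each metric key with the task name, e.g. 'color_accuracy'.
        eval_results.update({f'{task}_{key}': value
                             for key, value in task_eval.items()})
    return eval_results
```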
22 changes: 18 additions & 4 deletions mmcls/datasets/multi_label.py
@@ -44,6 +44,23 @@ def evaluate(self,
Returns:
dict: evaluation results
"""
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
return self.evaluate_multi_label(
results=results,
gt_labels=gt_labels,
metric=metric,
metric_options=metric_options,
logger=logger)

@staticmethod
def evaluate_multi_label(results,
gt_labels,
metric='mAP',
metric_options=None,
logger=None):
if metric_options is None or metric_options == {}:
metric_options = {'thr': 0.5}

@@ -53,10 +70,7 @@ def evaluate(self,
metrics = metric
allowed_metrics = ['mAP', 'CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]

num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
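The same pattern applies here: `evaluate_multi_label` is now `self`-free, so the multi-label metrics can be applied to a single task's predictions without a dataset instance. A minimal sketch, assuming `results` is a list of per-batch score arrays stacking to (num_imgs, num_classes) and `gt_labels` the matching binary label matrix:

```python
import numpy as np

from mmcls.datasets import MultiLabelDataset


def eval_one_multilabel_task(results, gt_labels):
    """Illustrative helper: score one task's multi-label predictions."""
    return MultiLabelDataset.evaluate_multi_label(
        results=np.vstack(results),
        gt_labels=gt_labels,
        metric=['mAP', 'CF1', 'OF1'],   # allowed metric names from above
        metric_options={'thr': 0.5})
```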