Skip to content

Commit

Permalink
[Feature] Support Multi-task. (#1229)
Browse files Browse the repository at this point in the history
* unit test for multi_task_head

* [Feature] MultiTaskHead (#628, #481)

* [Fix] lint for multi_task_head

* [Feature] Add `MultiTaskDataset` to support multi-task training.

* Update MultiTaskClsHead

* Update docs

* [CI] Add test mim CI. (#879)

* [Fix] Remove duplicated wide-resnet metafile.

* [Feature] Support MPS device. (#894)

* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests

* [Fix] Fix Albu crash bug. (#918)

* Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning

* Fix common

* Using copy incase potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf

* Bump version to 0.23.2. (#937)

* [Improve] Use `forward_dummy` to calculate FLOPS. (#953)

* Update README

* [Docs] Fix typo for wrong reference. (#1036)

* [Doc] Fix typo in tutorial 2 (#1043)

* [Docs] Fix a typo in ImageClassifier (#1050)

* add mask to loss

* add another pipeline

* adpat the pipeline if there is no mask

* switch mask and task

* first version of multi data smaple

* fix problem with attribut by getattr

* rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label'

* training  without evaluation

* first version work

* add others metrics

* delete evaluation from dataset

* fix linter

* fix linter

* multi metrics

* first version of test

* change evaluate metric

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* add tests

* add test for multidatasample

* create a generic test

* create a generic test

* create a generic test

* change multi data sample

* correct test

* test

* add new test

* add test for dataset

* correct test

* correct test

* correct test

* correct test

* fix : #5

* run yapf

* fix linter

* fix linter

* fix linter

* fix isort

* fix isort

* fix docformmater

* fix docformmater

* fix linter

* fix linter

* fix data sample

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update data sample

* update head

* update head

* update multi data sample

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* update head

* fix problem we don't  set pred or  gt

* fix problem we don't  set pred or  gt

* fix problem we don't  set pred or  gt

* fix linter

* fix : #2

* fix : linter

* update multi head

* fix linter

* fix linter

* update data sample

* update data sample

* fix ; linter

* update test

* test pipeline

* update pipeline

* update test

* update dataset

* update dataset

* fix linter

* fix linter

* update formatting

* add test for multi-task-eval

* update formatting

* fix linter

* update test

* update

* add test

* update metrics

* update metrics

* add doc for functions

* fix linter

* training for multitask 1.x

* fix linter

* run flake8

* run linter

* update test

* add mask in evaluation

* update metric doc

* update metric doc

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update metric doc

* update metric doc

* Fix cannot import name MultiTaskDataSample

* fix test_datasets

* fix test_datasets

* fix linter

* add an example of multitask

* change name of configs dataset

* Refactor the multi-task support

* correct test and metric

* add test to multidatasample

* add test to multidatasample

* correct test

* correct metrics and clshead

* Update mmcls/models/heads/cls_head.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update cls_head.py documentation

* lint

* lint

* fix: lint

* fix linter

* add eval mask

* fix documentation

* fix: single_label.py back to 1.x

* Update mmcls/models/heads/multi_task_head.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Remove multi-task configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: marouaneamz <maroineamil99@gmail.com>
Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
  • Loading branch information
8 people committed Dec 30, 2022
1 parent 5b266d9 commit bac181f
Show file tree
Hide file tree
Showing 19 changed files with 1,185 additions and 41 deletions.
3 changes: 2 additions & 1 deletion mmcls/datasets/__init__.py
Expand Up @@ -8,12 +8,13 @@
from .imagenet import ImageNet, ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .multi_task import MultiTaskDataset
from .samplers import * # noqa: F401,F403
from .transforms import * # noqa: F401,F403
from .voc import VOC

__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
'CustomDataset', 'MultiLabelDataset'
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
]
344 changes: 344 additions & 0 deletions mmcls/datasets/multi_task.py
@@ -0,0 +1,344 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence

import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import FileClient

from .builder import DATASETS


def expanduser(path):
if isinstance(path, (str, PathLike)):
return osp.expanduser(path)
else:
return path


def isabs(uri):
return osp.isabs(uri) or ('://' in uri)


@DATASETS.register_module()
class MultiTaskDataset:
"""Custom dataset for multi-task dataset.
To use the dataset, please generate and provide an annotation file in the
below format:
.. code-block:: json
{
"metainfo": {
"tasks":
[
'gender'
'wear'
]
},
"data_list": [
{
"img_path": "a.jpg",
gt_label:{
"gender": 0,
"wear": [1, 0, 1, 0]
}
},
{
"img_path": "b.jpg",
gt_label:{
"gender": 1,
"wear": [1, 0, 1, 0]
}
}
]
}
Assume we put our dataset in the ``data/mydataset`` folder in the
repository and organize it as the below format: ::
mmclassification/
└── data
└── mydataset
├── annotation
│   ├── train.json
│   ├── test.json
│   └── val.json
├── train
│   ├── a.jpg
│   └── ...
├── test
│   ├── b.jpg
│   └── ...
└── val
├── c.jpg
└── ...
We can use the below config to build datasets:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="annotation/train.json",
... data_root="data/mydataset",
... # The `img_path` field in the train annotation file is relative
... # to the `train` folder.
... data_prefix='train',
... )
>>> train_dataset = build_dataset(train_cfg)
Or we can put all files in the same folder: ::
mmclassification/
└── data
└── mydataset
├── train.json
├── test.json
├── val.json
├── a.jpg
├── b.jpg
├── c.jpg
└── ...
And we can use the below config to build datasets:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="train.json",
... data_root="data/mydataset",
... # the `data_prefix` is not required since all paths are
... # relative to the `data_root`.
... )
>>> train_dataset = build_dataset(train_cfg)
Args:
ann_file (str): The annotation file path. It can be either absolute
path or relative path to the ``data_root``.
metainfo (dict, optional): The extra meta information. It should be
a dict with the same format as the ``"metainfo"`` field in the
annotation file. Defaults to None.
data_root (str, optional): The root path of the data directory. It's
the prefix of the ``data_prefix`` and the ``ann_file``. And it can
be a remote path like "s3://openmmlab/xxx/". Defaults to None.
data_prefix (str, optional): The base folder relative to the
``data_root`` for the ``"img_path"`` field in the annotation file.
Defaults to None.
pipeline (Sequence[dict]): A list of dict, where each element
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
test_mode (bool): in train mode or test mode. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
If None, automatically inference from the ``data_root``.
Defaults to None.
"""
METAINFO = dict()

def __init__(self,
ann_file: str,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: Optional[str] = None,
pipeline: Sequence = (),
test_mode: bool = False,
file_client_args: Optional[dict] = None):

self.data_root = expanduser(data_root)

# Inference the file client
if self.data_root is not None:
file_client = FileClient.infer_client(
file_client_args, uri=self.data_root)
else:
file_client = FileClient(file_client_args)
self.file_client: FileClient = file_client

self.ann_file = self._join_root(expanduser(ann_file))
self.data_prefix = self._join_root(data_prefix)

self.test_mode = test_mode
self.pipeline = Compose(pipeline)
self.data_list = self.load_data_list(self.ann_file, metainfo)

def _join_root(self, path):
"""Join ``self.data_root`` with the specified path.
If the path is an absolute path, just return the path. And if the
path is None, return ``self.data_root``.
Examples:
>>> self.data_root = 'a/b/c'
>>> self._join_root('d/e/')
'a/b/c/d/e'
>>> self._join_root('https://openmmlab.com')
'https://openmmlab.com'
>>> self._join_root(None)
'a/b/c'
"""
if path is None:
return self.data_root
if isabs(path):
return path

joined_path = self.file_client.join_path(self.data_root, path)
return joined_path

@classmethod
def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
Args:
in_metainfo (dict): Meta information dict.
Returns:
dict: Parsed meta information.
"""
# `cls.METAINFO` will be overwritten by in_meta
metainfo = copy.deepcopy(cls.METAINFO)
if in_metainfo is None:
return metainfo

metainfo.update(in_metainfo)

return metainfo

def load_data_list(self, ann_file, metainfo_override=None):
"""Load annotations from an annotation file.
Args:
ann_file (str): Absolute annotation file path if ``self.root=None``
or relative path if ``self.root=/path/to/data/``.
Returns:
list[dict]: A list of annotation.
"""
annotations = mmengine.load(ann_file)
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '
f'should be a dict, but got {type(annotations)}!')
if 'data_list' not in annotations:
raise ValueError('The annotation file must have the `data_list` '
'field.')
metainfo = annotations.get('metainfo', {})
raw_data_list = annotations['data_list']

# Set meta information.
assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
f'annotation file should be a dict, but got {type(metainfo)}'
if metainfo_override is not None:
assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
f'argument should be a dict, but got {type(metainfo_override)}'
metainfo.update(metainfo_override)
self._metainfo = self._get_meta_info(metainfo)

data_list = []
for i, raw_data in enumerate(raw_data_list):
try:
data_list.append(self.parse_data_info(raw_data))
except AssertionError as e:
raise RuntimeError(
f'The format check fails during parse the item {i} of '
f'the annotation file with error: {e}')
return data_list

def parse_data_info(self, raw_data):
"""Parse raw annotation to target format.
This method will return a dict which contains the data information of a
sample.
Args:
raw_data (dict): Raw data information load from ``ann_file``
Returns:
dict: Parsed annotation.
"""
assert isinstance(raw_data, dict), \
f'The item should be a dict, but got {type(raw_data)}'
assert 'img_path' in raw_data, \
"The item doesn't have `img_path` field."
data = dict(
img_path=self._join_root(raw_data['img_path']),
gt_label=raw_data['gt_label'],
)
return data

@property
def metainfo(self) -> dict:
"""Get meta information of dataset.
Returns:
dict: meta information collected from ``cls.METAINFO``,
annotation file and metainfo argument during instantiation.
"""
return copy.deepcopy(self._metainfo)

def prepare_data(self, idx):
"""Get data processed by ``self.pipeline``.
Args:
idx (int): The index of ``data_info``.
Returns:
Any: Depends on ``self.pipeline``.
"""
results = copy.deepcopy(self.data_list[idx])
return self.pipeline(results)

def __len__(self):
"""Get the length of the whole dataset.
Returns:
int: The length of filtered dataset.
"""
return len(self.data_list)

def __getitem__(self, idx):
"""Get the idx-th image and data information of dataset after
``self.pipeline``.
Args:
idx (int): The index of of the data.
Returns:
dict: The idx-th image and data information after
``self.pipeline``.
"""
return self.prepare_data(idx)

def __repr__(self):
"""Print the basic information of the dataset.
Returns:
str: Formatted string.
"""
head = 'Dataset ' + self.__class__.__name__
body = [f'Number of samples: \t{self.__len__()}']
if self.data_root is not None:
body.append(f'Root location: \t{self.data_root}')
body.append(f'Annotation file: \t{self.ann_file}')
if self.data_prefix is not None:
body.append(f'Prefix of images: \t{self.data_prefix}')
# -------------------- extra repr --------------------
tasks = self.metainfo['tasks']
body.append(f'For {len(tasks)} tasks')
for task in tasks:
body.append(f' {task} ')
# ----------------------------------------------------

if len(self.pipeline.transforms) > 0:
body.append('With transforms:')
for t in self.pipeline.transforms:
body.append(f' {t}')

lines = [head] + [' ' * 4 + line for line in body]
return '\n'.join(lines)
6 changes: 4 additions & 2 deletions mmcls/datasets/transforms/__init__.py
Expand Up @@ -3,7 +3,8 @@
Brightness, ColorTransform, Contrast, Cutout,
Equalize, Invert, Posterize, RandAugment, Rotate,
Sharpness, Shear, Solarize, SolarizeAdd, Translate)
from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose
from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop, ResizeEdge)
Expand All @@ -15,5 +16,6 @@
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform'
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
'PackMultiTaskInputs'
]

0 comments on commit bac181f

Please sign in to comment.