From bac181f39302cce6f54885fd50a4b7e1fcc1cdd9 Mon Sep 17 00:00:00 2001 From: Colle Date: Fri, 30 Dec 2022 03:36:00 +0100 Subject: [PATCH] [Feature] Support Multi-task. (#1229) * unit test for multi_task_head * [Feature] MultiTaskHead (#628, #481) * [Fix] lint for multi_task_head * [Feature] Add `MultiTaskDataset` to support multi-task training. * Update MultiTaskClsHead * Update docs * [CI] Add test mim CI. (#879) * [Fix] Remove duplicated wide-resnet metafile. * [Feature] Support MPS device. (#894) * [Feature] Support MPS device. * Add `auto_select_device` * Add unit tests * [Fix] Fix Albu crash bug. (#918) * Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning * Fix common * Using copy incase potential bug in multi-label tasks * Improve coding * Improve code logic * Add unit test * Fix typo * Fix yapf * Bump version to 0.23.2. (#937) * [Improve] Use `forward_dummy` to calculate FLOPS. (#953) * Update README * [Docs] Fix typo for wrong reference. (#1036) * [Doc] Fix typo in tutorial 2 (#1043) * [Docs] Fix a typo in ImageClassifier (#1050) * add mask to loss * add another pipeline * adpat the pipeline if there is no mask * switch mask and task * first version of multi data smaple * fix problem with attribut by getattr * rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label' * training without evaluation * first version work * add others metrics * delete evaluation from dataset * fix linter * fix linter * multi metrics * first version of test * change evaluate metric * Update tests/test_models/test_heads.py Co-authored-by: Colle * Update tests/test_models/test_heads.py Co-authored-by: Colle * add tests * add test for multidatasample * create a generic test * create a generic test * create a generic test * change multi data sample * correct test * test * add new test * add test for dataset * correct test * correct test * correct test * correct test * fix : #5 * run yapf * fix linter * fix linter * fix linter * fix isort * fix isort * fix docformmater * fix docformmater * fix linter * fix linter * fix data sample * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update tests/test_structures/test_datasample.py Co-authored-by: Colle * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle * Update tests/test_structures/test_datasample.py Co-authored-by: Colle * Update tests/test_structures/test_datasample.py Co-authored-by: Colle * update data sample * update head * update head * update multi data sample * fix linter * fix linter * fix linter * fix linter * fix linter * fix linter * update head * fix problem we don't set pred or gt * fix problem we don't set pred or gt * fix problem we don't set pred or gt * fix linter * fix : #2 * fix : linter * update multi head * fix linter * fix linter * update data sample * update data sample * fix ; linter * update test * test pipeline * update pipeline * update test * update dataset * update dataset * fix linter * fix linter * update formatting * add test for multi-task-eval * update formatting * fix linter * update test * update * add test * update metrics * update metrics * add doc for 
functions * fix linter * training for multitask 1.x * fix linter * run flake8 * run linter * update test * add mask in evaluation * update metric doc * update metric doc * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle * update metric doc * update metric doc * Fix cannot import name MultiTaskDataSample * fix test_datasets * fix test_datasets * fix linter * add an example of multitask * change name of configs dataset * Refactor the multi-task support * correct test and metric * add test to multidatasample * add test to multidatasample * correct test * correct metrics and clshead * Update mmcls/models/heads/cls_head.py Co-authored-by: Colle * update cls_head.py documentation * lint * lint * fix: lint * fix linter * add eval mask * fix documentation * fix: single_label.py back to 1.x * Update mmcls/models/heads/multi_task_head.py Co-authored-by: Ma Zerun * Remove multi-task configs. Co-authored-by: mzr1996 Co-authored-by: HinGwenWoong Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: marouaneamz Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com> --- mmcls/datasets/__init__.py | 3 +- mmcls/datasets/multi_task.py | 344 ++++++++++++++++++ mmcls/datasets/transforms/__init__.py | 6 +- mmcls/datasets/transforms/formatting.py | 90 ++++- mmcls/evaluation/metrics/__init__.py | 3 +- mmcls/evaluation/metrics/multi_task.py | 120 ++++++ mmcls/models/heads/__init__.py | 4 +- mmcls/models/heads/cls_head.py | 30 +- mmcls/models/heads/multi_task_head.py | 139 +++++++ mmcls/models/utils/data_preprocessor.py | 9 +- mmcls/structures/__init__.py | 3 +- mmcls/structures/multi_task_data_sample.py | 10 + tests/__init__.py | 1 + tests/data/dataset/multi-task.json | 40 ++ tests/test_datasets/test_datasets.py | 75 +++- .../test_transforms/test_formatting.py | 58 ++- .../test_metrics/test_multi_task_metrics.py | 134 +++++++ tests/test_models/test_heads.py | 138 ++++++- tests/test_structures/test_datasample.py | 19 +- 19 files changed, 1185 insertions(+), 41 deletions(-) create mode 100644 mmcls/datasets/multi_task.py create mode 100644 mmcls/evaluation/metrics/multi_task.py create mode 100644 mmcls/models/heads/multi_task_head.py create mode 100644 mmcls/structures/multi_task_data_sample.py create mode 100644 tests/__init__.py create mode 100644 tests/data/dataset/multi-task.json create mode 100644 tests/test_evaluation/test_metrics/test_multi_task_metrics.py diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py index 0097a693299..22abdadcc51 100644 --- a/mmcls/datasets/__init__.py +++ b/mmcls/datasets/__init__.py @@ -8,6 +8,7 @@ from .imagenet import ImageNet, ImageNet21k from .mnist import MNIST, FashionMNIST from .multi_label import MultiLabelDataset +from .multi_task import MultiTaskDataset from .samplers import * # noqa: F401,F403 from .transforms import * # noqa: F401,F403 from .voc import VOC @@ -15,5 +16,5 @@ __all__ = [ 'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST', 'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB', - 
'CustomDataset', 'MultiLabelDataset' + 'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset' ] diff --git a/mmcls/datasets/multi_task.py b/mmcls/datasets/multi_task.py new file mode 100644 index 00000000000..a28b4982002 --- /dev/null +++ b/mmcls/datasets/multi_task.py @@ -0,0 +1,344 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from os import PathLike +from typing import Optional, Sequence + +import mmengine +from mmcv.transforms import Compose +from mmengine.fileio import FileClient + +from .builder import DATASETS + + +def expanduser(path): + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +def isabs(uri): + return osp.isabs(uri) or ('://' in uri) + + +@DATASETS.register_module() +class MultiTaskDataset: + """Custom dataset for multi-task classification. + + To use the dataset, please generate and provide an annotation file in the + below format: + + .. code-block:: json + + { + "metainfo": { + "tasks": + [ + "gender", + "wear" + ] + }, + "data_list": [ + { + "img_path": "a.jpg", + "gt_label": { + "gender": 0, + "wear": [1, 0, 1, 0] + } + }, + { + "img_path": "b.jpg", + "gt_label": { + "gender": 1, + "wear": [1, 0, 1, 0] + } + } + ] + } + + Assume we put our dataset in the ``data/mydataset`` folder in the + repository and organize it as the below format: :: + + mmclassification/ + └── data + └── mydataset + ├── annotation + │   ├── train.json + │   ├── test.json + │   └── val.json + ├── train + │   ├── a.jpg + │   └── ... + ├── test + │   ├── b.jpg + │   └── ... + └── val + ├── c.jpg + └── ... + + We can use the below config to build datasets: + + .. code:: python + + >>> from mmcls.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="annotation/train.json", + ... data_root="data/mydataset", + ... # The `img_path` field in the train annotation file is relative + ... # to the `train` folder. + ... data_prefix='train', + ... ) + >>> train_dataset = build_dataset(train_cfg) + + Or we can put all files in the same folder: :: + + mmclassification/ + └── data + └── mydataset + ├── train.json + ├── test.json + ├── val.json + ├── a.jpg + ├── b.jpg + ├── c.jpg + └── ... + + And we can use the below config to build datasets: + + .. code:: python + + >>> from mmcls.datasets import build_dataset + >>> train_cfg = dict( + ... type="MultiTaskDataset", + ... ann_file="train.json", + ... data_root="data/mydataset", + ... # the `data_prefix` is not required since all paths are + ... # relative to the `data_root`. + ... ) + >>> train_dataset = build_dataset(train_cfg) + + + Args: + ann_file (str): The annotation file path. It can be either absolute + path or relative path to the ``data_root``. + metainfo (dict, optional): The extra meta information. It should be + a dict with the same format as the ``"metainfo"`` field in the + annotation file. Defaults to None. + data_root (str, optional): The root path of the data directory. It's + the prefix of the ``data_prefix`` and the ``ann_file``. And it can + be a remote path like "s3://openmmlab/xxx/". Defaults to None. + data_prefix (str, optional): The base folder relative to the + ``data_root`` for the ``"img_path"`` field in the annotation file. + Defaults to None. + pipeline (Sequence[dict]): A list of dict, where each element + represents an operation defined in :mod:`mmcls.datasets.pipelines`. + Defaults to an empty tuple. + test_mode (bool): Whether in test mode. Defaults to False.
+ file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + If None, automatically inference from the ``data_root``. + Defaults to None. + """ + METAINFO = dict() + + def __init__(self, + ann_file: str, + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: Optional[str] = None, + pipeline: Sequence = (), + test_mode: bool = False, + file_client_args: Optional[dict] = None): + + self.data_root = expanduser(data_root) + + # Inference the file client + if self.data_root is not None: + file_client = FileClient.infer_client( + file_client_args, uri=self.data_root) + else: + file_client = FileClient(file_client_args) + self.file_client: FileClient = file_client + + self.ann_file = self._join_root(expanduser(ann_file)) + self.data_prefix = self._join_root(data_prefix) + + self.test_mode = test_mode + self.pipeline = Compose(pipeline) + self.data_list = self.load_data_list(self.ann_file, metainfo) + + def _join_root(self, path): + """Join ``self.data_root`` with the specified path. + + If the path is an absolute path, just return the path. And if the + path is None, return ``self.data_root``. + + Examples: + >>> self.data_root = 'a/b/c' + >>> self._join_root('d/e/') + 'a/b/c/d/e' + >>> self._join_root('https://openmmlab.com') + 'https://openmmlab.com' + >>> self._join_root(None) + 'a/b/c' + """ + if path is None: + return self.data_root + if isabs(path): + return path + + joined_path = self.file_client.join_path(self.data_root, path) + return joined_path + + @classmethod + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + in_metainfo (dict): Meta information dict. + + Returns: + dict: Parsed meta information. + """ + # `cls.METAINFO` will be overwritten by in_meta + metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return metainfo + + metainfo.update(in_metainfo) + + return metainfo + + def load_data_list(self, ann_file, metainfo_override=None): + """Load annotations from an annotation file. + + Args: + ann_file (str): Absolute annotation file path if ``self.root=None`` + or relative path if ``self.root=/path/to/data/``. + + Returns: + list[dict]: A list of annotation. + """ + annotations = mmengine.load(ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + if 'data_list' not in annotations: + raise ValueError('The annotation file must have the `data_list` ' + 'field.') + metainfo = annotations.get('metainfo', {}) + raw_data_list = annotations['data_list'] + + # Set meta information. + assert isinstance(metainfo, dict), 'The `metainfo` field in the '\ + f'annotation file should be a dict, but got {type(metainfo)}' + if metainfo_override is not None: + assert isinstance(metainfo_override, dict), 'The `metainfo` ' \ + f'argument should be a dict, but got {type(metainfo_override)}' + metainfo.update(metainfo_override) + self._metainfo = self._get_meta_info(metainfo) + + data_list = [] + for i, raw_data in enumerate(raw_data_list): + try: + data_list.append(self.parse_data_info(raw_data)) + except AssertionError as e: + raise RuntimeError( + f'The format check fails during parse the item {i} of ' + f'the annotation file with error: {e}') + return data_list + + def parse_data_info(self, raw_data): + """Parse raw annotation to target format. 
+ + This method will return a dict which contains the data information of a + sample. + + Args: + raw_data (dict): Raw data information load from ``ann_file`` + + Returns: + dict: Parsed annotation. + """ + assert isinstance(raw_data, dict), \ + f'The item should be a dict, but got {type(raw_data)}' + assert 'img_path' in raw_data, \ + "The item doesn't have `img_path` field." + data = dict( + img_path=self._join_root(raw_data['img_path']), + gt_label=raw_data['gt_label'], + ) + return data + + @property + def metainfo(self) -> dict: + """Get meta information of dataset. + + Returns: + dict: meta information collected from ``cls.METAINFO``, + annotation file and metainfo argument during instantiation. + """ + return copy.deepcopy(self._metainfo) + + def prepare_data(self, idx): + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + results = copy.deepcopy(self.data_list[idx]) + return self.pipeline(results) + + def __len__(self): + """Get the length of the whole dataset. + + Returns: + int: The length of filtered dataset. + """ + return len(self.data_list) + + def __getitem__(self, idx): + """Get the idx-th image and data information of dataset after + ``self.pipeline``. + + Args: + idx (int): The index of of the data. + + Returns: + dict: The idx-th image and data information after + ``self.pipeline``. + """ + return self.prepare_data(idx) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. + """ + head = 'Dataset ' + self.__class__.__name__ + body = [f'Number of samples: \t{self.__len__()}'] + if self.data_root is not None: + body.append(f'Root location: \t{self.data_root}') + body.append(f'Annotation file: \t{self.ann_file}') + if self.data_prefix is not None: + body.append(f'Prefix of images: \t{self.data_prefix}') + # -------------------- extra repr -------------------- + tasks = self.metainfo['tasks'] + body.append(f'For {len(tasks)} tasks') + for task in tasks: + body.append(f' {task} ') + # ---------------------------------------------------- + + if len(self.pipeline.transforms) > 0: + body.append('With transforms:') + for t in self.pipeline.transforms: + body.append(f' {t}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) diff --git a/mmcls/datasets/transforms/__init__.py b/mmcls/datasets/transforms/__init__.py index 8ea8db3d78f..1fa905a56cf 100644 --- a/mmcls/datasets/transforms/__init__.py +++ b/mmcls/datasets/transforms/__init__.py @@ -3,7 +3,8 @@ Brightness, ColorTransform, Contrast, Cutout, Equalize, Invert, Posterize, RandAugment, Rotate, Sharpness, Shear, Solarize, SolarizeAdd, Translate) -from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose +from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy, + ToPIL, Transpose) from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop, EfficientNetRandomCrop, Lighting, RandomCrop, RandomErasing, RandomResizedCrop, ResizeEdge) @@ -15,5 +16,6 @@ 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop', - 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform' + 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', + 'PackMultiTaskInputs' ] diff --git a/mmcls/datasets/transforms/formatting.py b/mmcls/datasets/transforms/formatting.py index 
c413d6f3fd0..d96ffed93cb 100644 --- a/mmcls/datasets/transforms/formatting.py +++ b/mmcls/datasets/transforms/formatting.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -import warnings +from collections import defaultdict from collections.abc import Sequence +from functools import partial import numpy as np import torch @@ -9,7 +10,7 @@ from PIL import Image from mmcls.registry import TRANSFORMS -from mmcls.structures import ClsDataSample +from mmcls.structures import ClsDataSample, MultiTaskDataSample def to_tensor(data): @@ -85,12 +86,6 @@ def transform(self, results: dict) -> dict: img = np.expand_dims(img, -1) img = np.ascontiguousarray(img.transpose(2, 0, 1)) packed_results['inputs'] = to_tensor(img) - else: - warnings.warn( - 'Cannot get "img" in the input dict of `PackClsInputs`,' - 'please make sure `LoadImageFromFile` has been added ' - 'in the data pipeline or images have been loaded in ' - 'the dataset.') data_sample = ClsDataSample() if 'gt_label' in results: @@ -100,7 +95,6 @@ def transform(self, results: dict) -> dict: img_meta = {k: results[k] for k in self.meta_keys if k in results} data_sample.set_metainfo(img_meta) packed_results['data_samples'] = data_sample - return packed_results def __repr__(self) -> str: @@ -109,6 +103,84 @@ def __repr__(self) -> str: return repr_str +@TRANSFORMS.register_module() +class PackMultiTaskInputs(BaseTransform): + """Convert all image labels of multi-task dataset to a dict of tensor. + + Args: + tasks (List[str]): The task names defined in the dataset. + meta_keys(Sequence[str]): The meta keys to be saved in the + ``metainfo`` of the packed ``data_samples``. + Defaults to a tuple includes keys: + + - ``sample_idx``: The id of the image sample. + - ``img_path``: The path to the image file. + - ``ori_shape``: The original shape of the image as a tuple (H, W). + - ``img_shape``: The shape of the image after the pipeline as a + tuple (H, W). + - ``scale_factor``: The scale factor between the resized image and + the original image. + - ``flip``: A boolean indicating if image flip transform was used. + - ``flip_direction``: The flipping direction. + """ + + def __init__(self, + task_handlers=dict(), + multi_task_fields=('gt_label', ), + meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')): + self.multi_task_fields = multi_task_fields + self.meta_keys = meta_keys + self.task_handlers = defaultdict( + partial(PackClsInputs, meta_keys=meta_keys)) + for task_name, task_handler in task_handlers.items(): + self.task_handlers[task_name] = TRANSFORMS.build( + dict(type=task_handler, meta_keys=meta_keys)) + + def transform(self, results: dict) -> dict: + """Method to pack the input data. 
+ + result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3}, + 'img': array([[[ 0, 0, 0]) + """ + packed_results = dict() + results = results.copy() + + if 'img' in results: + img = results.pop('img') + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + packed_results['inputs'] = to_tensor(img) + + task_results = defaultdict(dict) + for field in self.multi_task_fields: + if field in results: + value = results.pop(field) + for k, v in value.items(): + task_results[k].update({field: v}) + + data_sample = MultiTaskDataSample() + for task_name, task_result in task_results.items(): + task_handler = self.task_handlers[task_name] + task_pack_result = task_handler({**results, **task_result}) + data_sample.set_field(task_pack_result['data_samples'], task_name) + + packed_results['data_samples'] = data_sample + return packed_results + + def __repr__(self): + repr = self.__class__.__name__ + task_handlers = { + name: handler.__class__.__name__ + for name, handler in self.task_handlers.items() + } + repr += f'(task_handlers={task_handlers}, ' + repr += f'multi_task_fields={self.multi_task_fields}, ' + repr += f'meta_keys={self.meta_keys})' + return repr + + @TRANSFORMS.register_module() class Transpose(BaseTransform): """Transpose numpy array. diff --git a/mmcls/evaluation/metrics/__init__.py b/mmcls/evaluation/metrics/__init__.py index 25b4dc27148..78e02a291f8 100644 --- a/mmcls/evaluation/metrics/__init__.py +++ b/mmcls/evaluation/metrics/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .multi_label import AveragePrecision, MultiLabelMetric +from .multi_task import MultiTasksMetric from .single_label import Accuracy, SingleLabelMetric from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric __all__ = [ 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', - 'VOCAveragePrecision', 'VOCMultiLabelMetric' + 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric' ] diff --git a/mmcls/evaluation/metrics/multi_task.py b/mmcls/evaluation/metrics/multi_task.py new file mode 100644 index 00000000000..5f07bdd07d5 --- /dev/null +++ b/mmcls/evaluation/metrics/multi_task.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Sequence + +from mmengine.evaluator import BaseMetric + +from mmcls.registry import METRICS + + +@METRICS.register_module() +class MultiTasksMetric(BaseMetric): + """Metrics for multi-task evaluation. + Args: + task_metrics (dict): A dictionary whose keys are the task names and + whose values are lists of metric configs for the corresponding task. + Examples: + >>> import torch + >>> from mmcls.evaluation import MultiTasksMetric + # -------------------- The Basic Usage -------------------- + >>> task_metrics = { + 'task0': [dict(type='Accuracy', topk=(1, ))], + 'task1': [dict(type='Accuracy', topk=(1, 3))] + } + >>> pred = [{ + 'pred_task': { + 'task0': torch.tensor([0.7, 0.0, 0.3]), + 'task1': torch.tensor([0.5, 0.2, 0.3]) + }, + 'gt_task': { + 'task0': torch.tensor(0), + 'task1': torch.tensor(2) + } + }, { + 'pred_task': { + 'task0': torch.tensor([0.0, 0.0, 1.0]), + 'task1': torch.tensor([0.0, 0.0, 1.0]) + }, + 'gt_task': { + 'task0': torch.tensor(2), + 'task1': torch.tensor(2) + } + }] + >>> metric = MultiTasksMetric(task_metrics) + >>> metric.process(None, pred) + >>> results = metric.evaluate(2) + results = { + 'task0_accuracy/top1': 100.0, + 'task1_accuracy/top1': 50.0, + 'task1_accuracy/top3': 100.0 + } + """ + + def __init__(self, + task_metrics: Dict, + collect_device: str = 'cpu') -> None: + self.task_metrics = task_metrics + super().__init__(collect_device=collect_device) + + self._metrics = {} + for task_name in self.task_metrics.keys(): + self._metrics[task_name] = [] + for metric in self.task_metrics[task_name]: + self._metrics[task_name].append(METRICS.build(metric)) + + def process(self, data_batch, data_samples: Sequence[dict]): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for task_name in self.task_metrics.keys(): + filtered_data_samples = [] + for data_sample in data_samples: + eval_mask = data_sample[task_name]['eval_mask'] + if eval_mask: + filtered_data_samples.append(data_sample[task_name]) + for metric in self._metrics[task_name]: + metric.process(data_batch, filtered_data_samples) + + def compute_metrics(self, results: list) -> dict: + raise NotImplementedError( + 'compute metrics should not be used here directly') + + def evaluate(self, size): + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. When batch + size > 1, the dataloader may pad some data samples to make + sure all ranks have the same length of dataset slice. The + ``collect_results`` function will drop the padded data based on + this size. + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are + "{task_name}_{metric_name}", and the values + are the corresponding results.
+ """ + metrics = {} + for task_name in self._metrics: + for metric in self._metrics[task_name]: + name = metric.__class__.__name__ + if name == 'MultiTasksMetric' or metric.results: + results = metric.evaluate(size) + else: + results = {metric.__class__.__name__: 0} + for key in results: + name = f'{task_name}_{key}' + if name in results: + """Inspired from https://github.com/open- + mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2 + 747172e/mmengine/evalua tor/evaluator.py#L84-L87.""" + raise ValueError( + 'There are multiple metric results with the same' + f'metric name {name}. Please make sure all metrics' + 'have different prefixes.') + metrics[name] = results[key] + return metrics diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py index 3e359d37227..b5f5761ef46 100644 --- a/mmcls/models/heads/__init__.py +++ b/mmcls/models/heads/__init__.py @@ -8,11 +8,13 @@ from .multi_label_cls_head import MultiLabelClsHead from .multi_label_csra_head import CSRAClsHead from .multi_label_linear_head import MultiLabelLinearClsHead +from .multi_task_head import MultiTaskHead from .stacked_head import StackedLinearClsHead from .vision_transformer_head import VisionTransformerClsHead __all__ = [ 'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead', 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead', - 'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead' + 'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead', + 'MultiTaskHead' ] diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index 4af22f65764..1338947bf53 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -108,9 +108,10 @@ def _get_loss(self, cls_score: torch.Tensor, return losses def predict( - self, - feats: Tuple[torch.Tensor], - data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]: + self, + feats: Tuple[torch.Tensor], + data_samples: List[Union[ClsDataSample, None]] = None + ) -> List[ClsDataSample]: """Inference without augmentation. Args: @@ -118,7 +119,7 @@ def predict( Multiple stage inputs are acceptable but only the last stage will be used to classify. The shape of every item should be ``(num_samples, num_classes)``. - data_samples (List[ClsDataSample], optional): The annotation + data_samples (List[ClsDataSample | None], optional): The annotation data of every samples. If not None, set ``pred_label`` of the input data samples. Defaults to None. 
@@ -141,14 +142,15 @@ def _get_predictions(self, cls_score, data_samples): pred_scores = F.softmax(cls_score, dim=1) pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() - if data_samples is not None: - for data_sample, score, label in zip(data_samples, pred_scores, - pred_labels): - data_sample.set_pred_score(score).set_pred_label(label) - else: - data_samples = [] - for score, label in zip(pred_scores, pred_labels): - data_samples.append(ClsDataSample().set_pred_score( - score).set_pred_label(label)) + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = ClsDataSample() - return data_samples + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples diff --git a/mmcls/models/heads/multi_task_head.py b/mmcls/models/heads/multi_task_head.py new file mode 100644 index 00000000000..64167739f65 --- /dev/null +++ b/mmcls/models/heads/multi_task_head.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +import torch.nn as nn +from mmengine.model import ModuleDict + +from mmcls.registry import MODELS +from mmcls.structures import MultiTaskDataSample +from .base_head import BaseHead + + +def loss_convertor(loss_func, task_name): + + def wrapped(inputs, data_samples, **kwargs): + mask = torch.empty(len(data_samples), dtype=torch.bool) + task_data_samples = [] + for i, data_sample in enumerate(data_samples): + assert isinstance(data_sample, MultiTaskDataSample) + sample_mask = task_name in data_sample + mask[i] = sample_mask + if sample_mask: + task_data_samples.append(data_sample.get(task_name)) + + if len(task_data_samples) == 0: + return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)} + + # Mask the inputs of the task + def mask_inputs(inputs, mask): + if isinstance(inputs, Sequence): + return type(inputs)( + [mask_inputs(input, mask) for input in inputs]) + elif isinstance(inputs, torch.Tensor): + return inputs[mask] + + masked_inputs = mask_inputs(inputs, mask) + loss_output = loss_func(masked_inputs, task_data_samples, **kwargs) + loss_output['mask_size'] = mask.sum().to(torch.float) + return loss_output + + return wrapped + + +@MODELS.register_module() +class MultiTaskHead(BaseHead): + """Multi task head. + + Args: + task_heads (dict): Sub heads to use, the key will be used to name the + loss components. + **kwargs: Other keyword arguments shared by all sub heads, used as + the default arguments when building each sub head. + init_cfg (dict, optional): The extra initialization settings. + Defaults to None. + """ + + def __init__(self, task_heads, init_cfg=None, **kwargs): + super(MultiTaskHead, self).__init__(init_cfg=init_cfg) + + assert isinstance(task_heads, dict), 'The `task_heads` argument ' \ "should be a dict, whose keys are task names and values are " \ 'configs of the head for each task.'
+ + self.task_heads = ModuleDict() + + for task_name, sub_head in task_heads.items(): + if not isinstance(sub_head, nn.Module): + sub_head = MODELS.build(sub_head, default_args=kwargs) + sub_head.loss = loss_convertor(sub_head.loss, task_name) + self.task_heads[task_name] = sub_head + + def forward(self, feats): + """The forward process.""" + return { + task_name: head(feats) + for task_name, head in self.task_heads.items() + } + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample], **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components, each task loss + key will be prefixed by the task_name like "task1_loss" + """ + losses = dict() + for task_name, head in self.task_heads.items(): + head_loss = head.loss(feats, data_samples, **kwargs) + for k, v in head_loss.items(): + losses[f'{task_name}_{k}'] = v + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[MultiTaskDataSample] = None + ) -> List[MultiTaskDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + data_samples (List[MultiTaskDataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[MultiTaskDataSample]: A list of data samples which contains + the predicted results. + """ + predictions_dict = dict() + + for task_name, head in self.task_heads.items(): + task_samples = head.predict(feats) + batch_size = len(task_samples) + predictions_dict[task_name] = task_samples + + if data_samples is None: + data_samples = [MultiTaskDataSample() for _ in range(batch_size)] + + for task_name, task_samples in predictions_dict.items(): + for data_sample, task_sample in zip(data_samples, task_samples): + task_sample.set_field( + task_name in data_sample.tasks, + 'eval_mask', + field_type='metainfo') + + if task_name in data_sample.tasks: + data_sample.get(task_name).update(task_sample) + else: + data_sample.set_field(task_sample, task_name) + + return data_samples diff --git a/mmcls/models/utils/data_preprocessor.py b/mmcls/models/utils/data_preprocessor.py index 1da730c2f35..716b0a1eafa 100644 --- a/mmcls/models/utils/data_preprocessor.py +++ b/mmcls/models/utils/data_preprocessor.py @@ -8,7 +8,8 @@ from mmengine.model import BaseDataPreprocessor, stack_batch from mmcls.registry import MODELS -from mmcls.structures import (batch_label_to_onehot, cat_batch_labels, +from mmcls.structures import (ClsDataSample, MultiTaskDataSample, + batch_label_to_onehot, cat_batch_labels, stack_batch_scores, tensor_split) from .batch_augments import RandomBatchAugment @@ -151,7 +152,9 @@ def forward(self, data: dict, training: bool = False) -> dict: self.pad_value) data_samples = data.get('data_samples', None) - if data_samples is not None and 'gt_label' in data_samples[0]: + sample_item = data_samples[0] if data_samples is not None else None + if isinstance(sample_item, + ClsDataSample) and 'gt_label' in sample_item: gt_labels = [sample.gt_label for sample in data_samples] batch_label, label_indices = cat_batch_labels( gt_labels, device=self.device) @@ -181,5 +184,7 @@ def forward(self, data: dict, training: bool = 
False) -> dict: if batch_score is not None: for sample, score in zip(data_samples, batch_score): sample.set_gt_score(score) + elif isinstance(sample_item, MultiTaskDataSample): + data_samples = self.cast_data(data_samples) return {'inputs': inputs, 'data_samples': data_samples} diff --git a/mmcls/structures/__init__.py b/mmcls/structures/__init__.py index 0dc08443cab..3021d0a7d0b 100644 --- a/mmcls/structures/__init__.py +++ b/mmcls/structures/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cls_data_sample import ClsDataSample +from .multi_task_data_sample import MultiTaskDataSample from .utils import (batch_label_to_onehot, cat_batch_labels, stack_batch_scores, tensor_split) __all__ = [ 'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels', - 'stack_batch_scores', 'tensor_split' + 'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample' ] diff --git a/mmcls/structures/multi_task_data_sample.py b/mmcls/structures/multi_task_data_sample.py new file mode 100644 index 00000000000..f00993861bf --- /dev/null +++ b/mmcls/structures/multi_task_data_sample.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.structures import BaseDataElement + + +class MultiTaskDataSample(BaseDataElement): + + @property + def tasks(self): + return self._data_fields diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000000..ef101fec61e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/data/dataset/multi-task.json b/tests/data/dataset/multi-task.json new file mode 100644 index 00000000000..bf96384a0e6 --- /dev/null +++ b/tests/data/dataset/multi-task.json @@ -0,0 +1,40 @@ +{ + "metainfo": { + "tasks": [ + "gender", + "wear" + ] + }, + "data_list": [ + { + "img_path": "a/1.JPG", + "gt_label": { + "gender": 0 + } + }, + { + "img_path": "b/2.jpeg", + "gt_label": { + "gender": 0, + "wear": [ + 1, + 0, + 1, + 0 + ] + } + }, + { + "img_path": "b/subb/3.jpg", + "gt_label": { + "gender": 1, + "wear": [ + 0, + 1, + 0, + 1 + ] + } + } + ] +} diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py index f637fb9580d..eb2fab213e5 100644 --- a/tests/test_datasets/test_datasets.py +++ b/tests/test_datasets/test_datasets.py @@ -79,7 +79,7 @@ def test_repr(self): else: self.assertIn('The `CLASSES` meta info is not set.', repr(dataset)) - self.assertIn("Haven't been initialized", repr(dataset)) + self.assertIn('Haven\'t been initialized', repr(dataset)) dataset.full_init() self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset)) @@ -452,7 +452,7 @@ def test_extra_repr(self): cfg = {**self.DEFAULT_ARGS, 'lazy_init': True} dataset = dataset_class(**cfg) - self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}', + self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}", repr(dataset)) @classmethod @@ -597,7 +597,7 @@ def test_load_data_list(self): } dataset = dataset_class(**cfg) - self.assertIn("Haven't been initialized", repr(dataset)) + self.assertIn('Haven\'t been initialized', repr(dataset)) dataset.full_init() self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset)) @@ -770,7 +770,7 @@ def test_extra_repr(self): cfg = {**self.DEFAULT_ARGS, 'lazy_init': True} dataset = dataset_class(**cfg) - self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}', + self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}", repr(dataset)) @classmethod @@ -874,3 +874,70 @@ def 
test_extra_repr(self): @classmethod def tearDownClass(cls): cls.tmpdir.cleanup() + + +class TestMultiTaskDataset(TestCase): + DATASET_TYPE = 'MultiTaskDataset' + + DEFAULT_ARGS = dict( + data_root=ASSETS_ROOT, + ann_file=osp.join(ASSETS_ROOT, 'multi-task.json'), + pipeline=[]) + + def test_metainfo(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test default behavior + dataset = dataset_class(**self.DEFAULT_ARGS) + metainfo = {'tasks': ['gender', 'wear']} + self.assertDictEqual(dataset.metainfo, metainfo) + self.assertFalse(dataset.test_mode) + + def test_parse_data_info(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + dataset = dataset_class(**self.DEFAULT_ARGS) + + data = dataset.parse_data_info({ + 'img_path': 'a.jpg', + 'gt_label': { + 'gender': 0 + } + }) + self.assertDictContainsSubset( + { + 'img_path': os.path.join(ASSETS_ROOT, 'a.jpg'), + 'gt_label': { + 'gender': 0 + } + }, data) + np.testing.assert_equal(data['gt_label']['gender'], 0) + + # Test missing path + with self.assertRaisesRegex(AssertionError, 'have `img_path` field'): + dataset.parse_data_info( + {'gt_label': { + 'gender': 0, + 'wear': [1, 0, 1, 0] + }}) + + def test_repr(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + dataset = dataset_class(**self.DEFAULT_ARGS) + + task_doc = ('For 2 tasks\n gender \n wear ') + self.assertIn(task_doc, repr(dataset)) + + def test_load_data_list(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test default behavior + dataset = dataset_class(**self.DEFAULT_ARGS) + + data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file']) + self.assertIsInstance(data, list) + np.testing.assert_equal(len(data), 3) + np.testing.assert_equal(data[0]['gt_label'], {'gender': 0}) + np.testing.assert_equal(data[1]['gt_label'], { + 'gender': 0, + 'wear': [1, 0, 1, 0] + }) diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 0d271b3b0c6..6806b0b8060 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -10,7 +10,7 @@ from PIL import Image from mmcls.registry import TRANSFORMS -from mmcls.structures import ClsDataSample +from mmcls.structures import ClsDataSample, MultiTaskDataSample from mmcls.utils import register_all_modules register_all_modules() @@ -51,9 +51,8 @@ def test_transform(self): # Test without `img` and `gt_label` del data['img'] del data['gt_label'] - with self.assertWarnsRegex(Warning, 'Cannot get "img"'): - results = transform(copy.deepcopy(data)) - self.assertNotIn('gt_label', results['data_samples']) + results = transform(copy.deepcopy(data)) + self.assertNotIn('gt_label', results['data_samples']) def test_repr(self): cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape']) @@ -130,3 +129,54 @@ def test_repr(self): cfg = dict(type='Collect', keys=['img']) transform = TRANSFORMS.build(cfg) self.assertEqual(repr(transform), "Collect(keys=['img'])") + + +class TestPackMultiTaskInputs(unittest.TestCase): + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg') + data = { + 'sample_idx': 1, + 'img_path': img_path, + 'ori_shape': (300, 400), + 'img_shape': (300, 400), + 'scale_factor': 1.0, + 'flip': False, + 'img': mmcv.imread(img_path), + 'gt_label': { + 'task1': 1, + 'task3': 3 + }, + } + + cfg = dict(type='PackMultiTaskInputs', ) + transform = TRANSFORMS.build(cfg) + results = transform(copy.deepcopy(data)) + self.assertIn('inputs', 
results) + self.assertIsInstance(results['inputs'], torch.Tensor) + self.assertIn('data_samples', results) + self.assertIsInstance(results['data_samples'], MultiTaskDataSample) + self.assertIn('flip', results['data_samples'].task1.metainfo_keys()) + self.assertIsInstance(results['data_samples'].task1.gt_label, + LabelData) + + # Test grayscale image + data['img'] = data['img'].mean(-1) + results = transform(copy.deepcopy(data)) + self.assertIn('inputs', results) + self.assertIsInstance(results['inputs'], torch.Tensor) + self.assertEqual(results['inputs'].shape, (1, 300, 400)) + + # Test without `img` and `gt_label` + del data['img'] + del data['gt_label'] + results = transform(copy.deepcopy(data)) + self.assertNotIn('gt_label', results['data_samples']) + + def test_repr(self): + cfg = dict(type='PackMultiTaskInputs', meta_keys=['img_shape']) + transform = TRANSFORMS.build(cfg) + rep = 'PackMultiTaskInputs(task_handlers={},' + rep += ' multi_task_fields=(\'gt_label\',),' + rep += ' meta_keys=[\'img_shape\'])' + self.assertEqual(repr(transform), rep) diff --git a/tests/test_evaluation/test_metrics/test_multi_task_metrics.py b/tests/test_evaluation/test_metrics/test_multi_task_metrics.py new file mode 100644 index 00000000000..29e4d96d414 --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_multi_task_metrics.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmcls.evaluation.metrics import MultiTasksMetric +from mmcls.structures import ClsDataSample + + +class MultiTaskMetric(TestCase): + data_pred = [ + { + 'task0': torch.tensor([0.7, 0.0, 0.3]), + 'task1': torch.tensor([0.5, 0.2, 0.3]) + }, + { + 'task0': torch.tensor([0.0, 0.0, 1.0]), + 'task1': torch.tensor([0.0, 0.0, 1.0]) + }, + ] + data_gt = [{'task0': 0, 'task1': 2}, {'task1': 2}] + + preds = [] + for i, pred in enumerate(data_pred): + sample = {} + for task_name in pred: + task_sample = ClsDataSample().set_pred_score(pred[task_name]) + if task_name in data_gt[i]: + task_sample.set_gt_label(data_gt[i][task_name]) + task_sample.set_field(True, 'eval_mask', field_type='metainfo') + else: + task_sample.set_field( + False, 'eval_mask', field_type='metainfo') + sample[task_name] = task_sample.to_dict() + + preds.append(sample) + data2 = zip([ + { + 'task0': torch.tensor([0.7, 0.0, 0.3]), + 'task1': { + 'task10': torch.tensor([0.5, 0.2, 0.3]), + 'task11': torch.tensor([0.4, 0.3, 0.3]) + } + }, + { + 'task0': torch.tensor([0.0, 0.0, 1.0]), + 'task1': { + 'task10': torch.tensor([0.1, 0.6, 0.3]), + 'task11': torch.tensor([0.5, 0.2, 0.3]) + } + }, + ], [{ + 'task0': 0, + 'task1': { + 'task10': 2, + 'task11': 0 + } + }, { + 'task0': 2, + 'task1': { + 'task10': 1, + 'task11': 0 + } + }]) + + pred2 = [] + for score, label in data2: + sample = {} + for task_name in score: + if type(score[task_name]) != dict: + task_sample = ClsDataSample().set_pred_score(score[task_name]) + task_sample.set_gt_label(label[task_name]) + sample[task_name] = task_sample.to_dict() + sample[task_name]['eval_mask'] = True + else: + sample[task_name] = {} + sample[task_name]['eval_mask'] = True + for task_name2 in score[task_name]: + task_sample = ClsDataSample().set_pred_score( + score[task_name][task_name2]) + task_sample.set_gt_label(label[task_name][task_name2]) + sample[task_name][task_name2] = task_sample.to_dict() + sample[task_name][task_name2]['eval_mask'] = True + + pred2.append(sample) + + pred3 = [{'task0': {'eval_mask': False}, 'task1': {'eval_mask': False}}] + task_metrics = { + 
'task0': [dict(type='Accuracy', topk=(1, ))], + 'task1': [ + dict(type='Accuracy', topk=(1, 3)), + dict(type='SingleLabelMetric', items=['precision', 'recall']) + ] + } + task_metrics2 = { + 'task0': [dict(type='Accuracy', topk=(1, ))], + 'task1': [ + dict( + type='MultiTasksMetric', + task_metrics={ + 'task10': [ + dict(type='Accuracy', topk=(1, 3)), + dict(type='SingleLabelMetric', items=['precision']) + ], + 'task11': [dict(type='Accuracy', topk=(1, ))] + }) + ] + } + + def test_evaluate(self): + """Test using the metric in the same way as Evalutor.""" + + # Test with score (use score instead of label if score exists) + metric = MultiTasksMetric(self.task_metrics) + metric.process(None, self.preds) + results = metric.evaluate(2) + self.assertIsInstance(results, dict) + self.assertAlmostEqual(results['task0_accuracy/top1'], 100) + self.assertGreater(results['task1_single-label/precision'], 0) + + # Test nested + metric = MultiTasksMetric(self.task_metrics2) + metric.process(None, self.pred2) + results = metric.evaluate(2) + self.assertIsInstance(results, dict) + self.assertGreater(results['task1_task10_single-label/precision'], 0) + self.assertGreater(results['task1_task11_accuracy/top1'], 0) + + # Test with without any ground truth value + metric = MultiTasksMetric(self.task_metrics) + metric.process(None, self.pred3) + results = metric.evaluate(2) + self.assertIsInstance(results, dict) + self.assertEqual(results['task0_Accuracy'], 0) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 0b1f72f1db8..85fdd7aa1fa 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -10,7 +10,7 @@ from mmengine import is_seq_of from mmcls.registry import MODELS -from mmcls.structures import ClsDataSample +from mmcls.structures import ClsDataSample, MultiTaskDataSample from mmcls.utils import register_all_modules register_all_modules() @@ -484,6 +484,142 @@ def test_forward(self): head(feats) +class TestMultiTaskHead(TestCase): + DEFAULT_ARGS = dict( + type='MultiTaskHead', # <- Head config, depends on #675 + task_heads={ + 'task0': dict(type='LinearClsHead', num_classes=3), + 'task1': dict(type='LinearClsHead', num_classes=6), + }, + in_channels=10, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + ) + + DEFAULT_ARGS2 = dict( + type='MultiTaskHead', # <- Head config, depends on #675 + task_heads={ + 'task0': + dict( + type='MultiTaskHead', + task_heads={ + 'task00': dict(type='LinearClsHead', num_classes=3), + 'task01': dict(type='LinearClsHead', num_classes=6), + }), + 'task1': + dict(type='LinearClsHead', num_classes=6) + }, + in_channels=10, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + ) + + def test_forward(self): + head = MODELS.build(self.DEFAULT_ARGS) + # return the last item (same as pre_logits) + feats = (torch.rand(4, 10), ) + outs = head(feats) + self.assertEqual(outs['task0'].shape, (4, 3)) + self.assertEqual(outs['task1'].shape, (4, 6)) + self.assertTrue(isinstance(outs, dict)) + + def test_loss(self): + feats = (torch.rand(4, 10), ) + data_samples = [] + + for _ in range(4): + data_sample = MultiTaskDataSample() + for task_name in self.DEFAULT_ARGS['task_heads']: + task_sample = ClsDataSample().set_gt_label(1) + data_sample.set_field(task_sample, task_name) + data_samples.append(data_sample) + # with cal_acc = False + head = MODELS.build(self.DEFAULT_ARGS) + + losses = head.loss(feats, data_samples) + self.assertEqual( + losses.keys(), + {'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'}) + 
self.assertGreater(losses['task0_loss'].item(), 0) + self.assertGreater(losses['task1_loss'].item(), 0) + + def test_predict(self): + feats = (torch.rand(4, 10), ) + data_samples = [] + + for _ in range(4): + data_sample = MultiTaskDataSample() + for task_name in self.DEFAULT_ARGS['task_heads']: + task_sample = ClsDataSample().set_gt_label(1) + data_sample.set_field(task_sample, task_name) + data_samples.append(data_sample) + head = MODELS.build(self.DEFAULT_ARGS) + # with without data_samples + predictions = head.predict(feats) + self.assertTrue(is_seq_of(predictions, MultiTaskDataSample)) + for pred in predictions: + self.assertIn('task0', pred) + task0_sample = predictions[0].task0 + self.assertTrue(type(task0_sample.pred_label.score), 'torch.tensor') + + # with with data_samples + predictions = head.predict(feats, data_samples) + self.assertTrue(is_seq_of(predictions, MultiTaskDataSample)) + for sample, pred in zip(data_samples, predictions): + self.assertIs(sample, pred) + self.assertIn('task0', pred) + + def test_loss_empty_data_sample(self): + feats = (torch.rand(4, 10), ) + data_samples = [] + + for _ in range(4): + data_sample = MultiTaskDataSample() + data_samples.append(data_sample) + # with cal_acc = False + head = MODELS.build(self.DEFAULT_ARGS) + losses = head.loss(feats, data_samples) + self.assertEqual( + losses.keys(), + {'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'}) + self.assertEqual(losses['task0_loss'].item(), 0) + self.assertEqual(losses['task1_loss'].item(), 0) + + def test_nested_multi_task_loss(self): + + head = MODELS.build(self.DEFAULT_ARGS2) + # return the last item (same as pre_logits) + feats = (torch.rand(4, 10), ) + outs = head(feats) + self.assertEqual(outs['task0']['task01'].shape, (4, 6)) + self.assertTrue(isinstance(outs, dict)) + self.assertTrue(isinstance(outs['task0'], dict)) + + def test_nested_invalid_sample(self): + feats = (torch.rand(4, 10), ) + gt_label = {'task0': 1, 'task1': 1} + head = MODELS.build(self.DEFAULT_ARGS2) + data_sample = MultiTaskDataSample() + for task_name in gt_label: + task_sample = ClsDataSample().set_gt_label(gt_label[task_name]) + data_sample.set_field(task_sample, task_name) + with self.assertRaises(Exception): + head.loss(feats, data_sample) + + def test_nested_invalid_sample2(self): + feats = (torch.rand(4, 10), ) + gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1} + head = MODELS.build(self.DEFAULT_ARGS) + data_sample = MultiTaskDataSample() + task_sample = ClsDataSample().set_gt_label(gt_label['task1']) + data_sample.set_field(task_sample, 'task1') + data_sample.set_field(MultiTaskDataSample(), 'task0') + for task_name in gt_label['task0']: + task_sample = ClsDataSample().set_gt_label( + gt_label['task0'][task_name]) + data_sample.task0.set_field(task_sample, task_name) + with self.assertRaises(Exception): + head.loss(feats, data_sample) + + class TestArcFaceClsHead(TestCase): DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5) diff --git a/tests/test_structures/test_datasample.py b/tests/test_structures/test_datasample.py index ee45c3f24a3..e02c95fc787 100644 --- a/tests/test_structures/test_datasample.py +++ b/tests/test_structures/test_datasample.py @@ -5,7 +5,7 @@ import torch from mmengine.structures import LabelData -from mmcls.structures import ClsDataSample +from mmcls.structures import ClsDataSample, MultiTaskDataSample class TestClsDataSample(TestCase): @@ -122,3 +122,20 @@ def test_set_pred_score(self): with self.assertRaisesRegex(AssertionError, 'but 
got 2'): data_sample.set_pred_score( torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]])) + + +class TestMultiTaskDataSample(TestCase): + + def test_multi_task_data_sample(self): + gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1} + data_sample = MultiTaskDataSample() + task_sample = ClsDataSample().set_gt_label(gt_label['task1']) + data_sample.set_field(task_sample, 'task1') + data_sample.set_field(MultiTaskDataSample(), 'task0') + for task_name in gt_label['task0']: + task_sample = ClsDataSample().set_gt_label( + gt_label['task0'][task_name]) + data_sample.task0.set_field(task_sample, task_name) + self.assertIsInstance(data_sample.task0, MultiTaskDataSample) + self.assertIsInstance(data_sample.task1, ClsDataSample) + self.assertIsInstance(data_sample.task0.task00, ClsDataSample)
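For readers who want to try the feature, below is a minimal end-to-end configuration sketch that wires together the components introduced by this patch (MultiTaskDataset, PackMultiTaskInputs, MultiTaskHead and MultiTasksMetric). The dataset paths, task names, class counts and the ResNet-18 backbone are illustrative assumptions, not part of the patch; only the multi-task components and their arguments come from the diffs above.

# Illustrative sketch only: paths, task names, class counts and backbone are assumed.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    # Packs every field listed in `multi_task_fields` into a MultiTaskDataSample.
    dict(type='PackMultiTaskInputs', multi_task_fields=('gt_label', )),
]

train_dataloader = dict(
    batch_size=32,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='MultiTaskDataset',
        data_root='data/mydataset',        # assumed layout, see the dataset docstring
        ann_file='annotation/train.json',
        data_prefix='train',
        pipeline=train_pipeline))

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18),  # any backbone; ResNet-18 assumed here
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskHead',
        task_heads=dict(
            task0=dict(type='LinearClsHead', num_classes=3),
            task1=dict(type='LinearClsHead', num_classes=6)),
        # The keyword arguments below are forwarded to every sub head.
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))

# One metric list per task; results are reported as "{task_name}_{metric_name}".
val_evaluator = dict(
    type='MultiTasksMetric',
    task_metrics={
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [dict(type='Accuracy', topk=(1, 3)),
                  dict(type='SingleLabelMetric', items=['precision', 'recall'])],
    })

Samples that lack the label for a given task are masked out automatically: MultiTaskHead reports a per-task `mask_size` during training, and MultiTasksMetric only evaluates samples whose `eval_mask` is set for that task, as shown in the diffs above.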