
[Feature] Add MultiTaskDataset to support multi-task training. #808

Closed

Conversation

mzr1996 (Member) commented Apr 29, 2022

Motivation

To support using a single backbone to perform multiple classification tasks.

Modification

This PR is one part of the multi-task support plan, and it depends on #675 to build a network.

BC-breaking (Optional)

No

Use cases

Here is a detailed design for multi-task support. First, "multi-task" here means using one backbone and multiple heads to classify an image that carries multiple kinds of labels.

Dataset

The current multi-task implementation requires full labels on every image, which means you cannot use partially labeled samples to train the multi-task model.

To create a multi-task dataset, you can use the MultiTaskDataset class and prepare an annotation file. Here is a brief example:

An example annotation JSON file:

{
  "metainfo": {
    "tasks":
      [
        {"name": "gender",
         "type": "single-label",
         "categories": ["male", "female"]},
        {"name": "wear",
         "type": "multi-label",
         "categories": ["shirt", "coat", "jeans", "pants"]}
      ]
  },
  "data_list": [
    {
      "img_path": "a.jpg",
      "gender_img_label": 0,
      "wear_img_label": [1, 0, 1, 0]
    },
    {
      "img_path": "b.jpg",
      "gender_img_label": 1,
      "wear_img_label": [0, 1, 0, 1]
    },
    ...
  ]
}
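
For clarity, the annotation format above can be parsed and validated with a short sketch. The function name and the validation rules here are illustrative assumptions, not the actual MultiTaskDataset code:

```python
import json

def load_multi_task_annotations(ann_text):
    """Parse the annotation JSON above and index labels by task name.

    Hypothetical helper: single-label tasks store a category index,
    multi-label tasks store one 0/1 flag per category.
    """
    ann = json.loads(ann_text)
    tasks = {t['name']: t for t in ann['metainfo']['tasks']}
    samples = []
    for item in ann['data_list']:
        labels = {}
        for name, task in tasks.items():
            value = item[f'{name}_img_label']
            if task['type'] == 'single-label':
                # A single-label value is an index into the categories list.
                assert 0 <= value < len(task['categories'])
            else:
                # A multi-label value has one flag per category.
                assert len(value) == len(task['categories'])
            labels[name] = value
        samples.append({'img_path': item['img_path'], 'gt_label': labels})
    return tasks, samples

example = '''{
  "metainfo": {"tasks": [
    {"name": "gender", "type": "single-label",
     "categories": ["male", "female"]},
    {"name": "wear", "type": "multi-label",
     "categories": ["shirt", "coat", "jeans", "pants"]}]},
  "data_list": [
    {"img_path": "a.jpg", "gender_img_label": 0,
     "wear_img_label": [1, 0, 1, 0]}]
}'''

tasks, samples = load_multi_task_annotations(example)
```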

The detailed usage and examples of the MultiTaskDataset can be found here

And here is a script that uses the CIFAR10 dataset to generate an example multi-task dataset; just run it in the data folder. The resulting file structure looks like this:

data/
├── cifar10
│   ├── images
│   ├── multi-task-test.json
│   └── multi-task-train.json
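
A minimal sketch of how such an annotation file could be generated from plain per-image records; the real conversion script referenced above is not reproduced here, and the record fields are illustrative assumptions:

```python
import json

def build_annotation(records, tasks):
    """Assemble a MultiTaskDataset-style annotation dict.

    Hypothetical helper: each record maps a task name to its label, and
    the output follows the `<task>_img_label` key convention shown above.
    """
    data_list = []
    for rec in records:
        item = {'img_path': rec['img_path']}
        for task in tasks:
            item[f"{task['name']}_img_label"] = rec[task['name']]
        data_list.append(item)
    return {'metainfo': {'tasks': tasks}, 'data_list': data_list}

tasks = [
    {'name': 'task1', 'type': 'single-label',
     'categories': [str(i) for i in range(6)]},
    {'name': 'task2', 'type': 'single-label',
     'categories': [str(i) for i in range(6)]},
]
records = [
    {'img_path': 'images/00000.png', 'task1': 3, 'task2': 5},
    {'img_path': 'images/00001.png', 'task1': 1, 'task2': 0},
]
ann = build_annotation(records, tasks)
# json.dumps(ann) can then be written to e.g. multi-task-train.json
```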

And here is an example config to train on the multi-task dataset.

# Save as `configs/resnet/multi-task-demo.py`
_base_ = ['../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py']

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet_CIFAR', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskClsHead',                                    # <- Head config, depends on #675
        sub_heads={
            'task1': dict(type='LinearClsHead', num_classes=6),
            'task2': dict(type='LinearClsHead', num_classes=6),
        },
        common_cfg=dict(
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        ),
    ),
)

# dataset settings
dataset_type = 'MultiTaskDataset'
img_norm_cfg = dict(
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='FormatMultiTaskLabels'),                             # <- Use this to replace `ToTensor`.
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-train.json',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_root='data/cifar10',
        ann_file='multi-task-test.json',
        pipeline=test_pipeline,
        test_mode=True))

evaluation = dict(metric_options={
    'task1': dict(topk=(1, )),                # <- Specify different metric options for different tasks.
    'task2': dict(topk=(1, 3)),
})
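
To illustrate what the per-task `metric_options` above produce, here is a small sketch of evaluating top-k accuracy separately for each task. This is not the actual mmcls evaluation code; the helper and the sample scores are made up for demonstration:

```python
def topk_accuracy(scores, labels, topk=(1,)):
    """scores: per-sample lists of class scores; labels: int targets."""
    results = {}
    for k in topk:
        correct = 0
        for score, label in zip(scores, labels):
            # Rank class indices by descending score, keep the top k.
            ranked = sorted(range(len(score)), key=lambda i: -score[i])
            if label in ranked[:k]:
                correct += 1
        results[f'accuracy_top-{k}'] = 100.0 * correct / len(labels)
    return results

# Each task is scored with its own topk options, mirroring the config:
metric_options = {'task1': dict(topk=(1,)), 'task2': dict(topk=(1, 3))}
preds = {'task1': [[0.1, 0.9], [0.8, 0.2]],
         'task2': [[0.2, 0.5, 0.3], [0.6, 0.1, 0.3]]}
gts = {'task1': [1, 0], 'task2': [1, 2]}
eval_results = {
    f'{task}_{name}': value
    for task, opts in metric_options.items()
    for name, value in topk_accuracy(preds[task], gts[task], **opts).items()
}
```

The resulting keys (`task1_accuracy_top-1`, `task2_accuracy_top-3`, ...) match the shape of the validation log shown below.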

Then, we can train on the dataset with `python tools/train.py configs/resnet/multi-task-demo.py`:

2022-04-29 18:25:37,968 - mmcls - INFO - workflow: [('train', 1)], max: 200 epochs
2022-04-29 18:25:37,968 - mmcls - INFO - Checkpoints will be saved to /home/work_dirs/multi-task-demo by HardDiskBackend.
2022-04-29 18:25:42,280 - mmcls - INFO - Epoch [1][100/2813]    lr: 1.000e-01, eta: 6:43:27, time: 0.043, data_time: 0.021, memory: 329, task1_loss: 1.7489, task2_loss: 1.6522, loss: 3.4011
...
2022-04-29 18:26:24,813 - mmcls - INFO - Saving checkpoint at 1 epochs
2022-04-29 18:26:26,951 - mmcls - INFO - Epoch(val) [1][313]    task1_accuracy_top-1: 62.7000, task2_accuracy_top-1: 65.6800, task2_accuracy_top-3: 96.4800

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • CLA has been signed and all committers have signed the CLA in this PR.

codecov bot commented Apr 29, 2022

Codecov Report

Merging #808 (21b0a38) into dev (59292b3) will increase coverage by 0.03%.
The diff coverage is 86.72%.

@@            Coverage Diff             @@
##              dev     #808      +/-   ##
==========================================
+ Coverage   87.02%   87.05%   +0.03%     
==========================================
  Files         130      131       +1     
  Lines        8538     8739     +201     
  Branches     1468     1512      +44     
==========================================
+ Hits         7430     7608     +178     
- Misses        888      895       +7     
- Partials      220      236      +16     
Flag        Coverage Δ
unittests   86.97% <86.72%> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files                            Coverage Δ
mmcls/datasets/multi_label.py             89.74% <71.42%> (+0.85%) ⬆️
mmcls/datasets/multi_task.py              85.71% <85.71%> (ø)
mmcls/datasets/pipelines/formatting.py    54.95% <94.73%> (+11.47%) ⬆️
mmcls/datasets/__init__.py                100.00% <100.00%> (ø)
mmcls/datasets/base_dataset.py            99.00% <100.00%> (+0.03%) ⬆️
mmcls/datasets/pipelines/__init__.py      100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update bce95b9...21b0a38.

JihwanEom (Contributor) commented:

Any update for this?

iamweiweishi commented:

Where can I find the `FormatMultiTaskLabels` class? Thank you.

piercus (Contributor) commented Sep 22, 2022

@mzr1996 @Ezra-Yu This is great, thanks for this; I did only half of the job in #675!

I have merged both MultiClsHead and MultiTaskDataset in my repo https://github.com/piercus/mmclassification/tree/multi-task

What are the next steps for this?
(1) Have you decided whether this feature makes sense in the core mmcls?
(2) Is there some cleanup to do?
(3) Do we need more tests or examples?

Thank you for your help

piercus mentioned this pull request on Nov 25, 2022.
mzr1996 (Member, Author) commented Jan 12, 2023

Closed since #1229 is merged

@mzr1996 mzr1996 closed this Jan 12, 2023