[Feature] Multi task head #675

Closed · wants to merge 6 commits
1 change: 1 addition & 0 deletions docs/en/api/models.rst
@@ -120,6 +120,7 @@ Heads
VisionTransformerClsHead
DeiTClsHead
ConformerHead
MultiTaskClsHead

.. _losses:

3 changes: 2 additions & 1 deletion mmcls/models/heads/__init__.py
@@ -5,11 +5,12 @@
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .multi_task_head import MultiTaskClsHead
from .stacked_head import StackedLinearClsHead
from .vision_transformer_head import VisionTransformerClsHead

__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
'ConformerHead'
'ConformerHead', 'MultiTaskClsHead'
]
112 changes: 112 additions & 0 deletions mmcls/models/heads/multi_task_head.py
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import ModuleDict

from ..builder import HEADS
from .base_head import BaseHead


@HEADS.register_module()
class MultiTaskClsHead(BaseHead):
"""Multi task head.

Args:
        sub_heads (dict): Sub-heads to use; the keys are task names and will
            be used to rename the loss components.
common_cfg (dict): The common settings for all heads. Defaults to an
empty dict.
init_cfg (dict, optional): The extra initialization settings.
Defaults to None.
"""

def __init__(self, sub_heads, common_cfg=dict(), init_cfg=None):
super(MultiTaskClsHead, self).__init__(init_cfg=init_cfg)

        assert isinstance(sub_heads, dict), 'The `sub_heads` argument ' \
            'should be a dict, whose keys are task names and whose values ' \
            'are configs of the head for each task.'

self.sub_heads = ModuleDict()

for task_name, head_cfg in sub_heads.items():
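            # ``common_cfg`` supplies default fields (e.g. the head type,
            # ``in_channels`` and the loss) which every sub-head config can
            # override with task-specific values.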
sub_head = HEADS.build(head_cfg, default_args=common_cfg)
self.sub_heads[task_name] = sub_head

def forward_train(self, x, gt_label, **kwargs):
losses = dict()
for task_name, head in self.sub_heads.items():
head_loss = head.forward_train(x, gt_label[task_name], **kwargs)
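            # Prefix every loss key with the task name so the loss components
            # of different tasks stay distinguishable in the logs.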
for k, v in head_loss.items():
losses[f'{task_name}_{k}'] = v
return losses

def pre_logits(self, x):
results = dict()
for task_name, head in self.sub_heads.items():
results[task_name] = head.pre_logits(x)
return results

def simple_test(self,
x,
post_process=True,
task_wise_args=dict(),
**kwargs):
"""Inference without augmentation.

Args:
            x (tuple[Tensor]): The input features, forwarded to every
                sub-head.
            post_process (bool): Whether to do post-processing for each task.
                If True, returns a list of results, where each item is the
                result dict of one sample. If False, returns a dict with the
                raw outputs of every task, without post-processing.
                Defaults to True.
            task_wise_args (dict): A dict of extra keyword arguments for
                specific sub-heads, keyed by task name. Defaults to an empty
                dict.

Returns:
dict | list[dict]: The inference results. The output type depends
on ``post_process``, and more details can be found in the examples.

Examples:
>>> import torch
>>> from mmcls.models import HEADS
>>>
>>> feats = torch.rand(3, 128)
>>> cfg = dict(
... type='MultiTaskClsHead',
... sub_heads={
... 'task1': dict(num_classes=5),
... 'task2': dict(num_classes=10),
... },
... common_cfg=dict(
... type='LinearClsHead',
... in_channels=128,
... loss=dict(type='CrossEntropyLoss')),
... )
>>> head = HEADS.build(cfg)
>>> # simple_test with post_process
>>> head.simple_test(feats, post_process=True)
[{'task1': array([...], dtype=float32),
'task2': array([...], dtype=float32)},
{'task1': array([...], dtype=float32),
'task2': array([...], dtype=float32)},
{'task1': array([...], dtype=float32),
'task2': array([...], dtype=float32)}]
>>> # simple_test without post_process
>>> head.simple_test(feats, post_process=False)
            {'task1': tensor(..., grad_fn=<...>),
             'task2': tensor(..., grad_fn=<...>)}
"""
results = dict()
for task_name, head in self.sub_heads.items():
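            # Later entries take precedence: per-task overrides from
            # ``task_wise_args`` win over the shared kwargs and the common
            # ``post_process`` flag.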
forward_args = {
'post_process': post_process,
**kwargs,
**task_wise_args.get(task_name, {})
}
results[task_name] = head.simple_test(x, **forward_args)

if post_process:
            # Convert the dict of per-task result lists into a list of
            # per-sample result dicts.
results = [dict(zip(results, t)) for t in zip(*results.values())]

return results
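
For reviewers who want to try the head end-to-end, here is a minimal config sketch of how `MultiTaskClsHead` could be plugged into an `ImageClassifier`. The backbone and neck choices, the channel size, and the task names are illustrative assumptions, not part of this diff:

```python
# Hypothetical usage sketch (assumptions: ResNet-18 backbone with 512
# output channels, GlobalAveragePooling neck; only the `head` part is new).
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskClsHead',
        # Per-task heads; each config may override any field of common_cfg.
        sub_heads={
            'task1': dict(num_classes=5),
            'task2': dict(num_classes=10),
        },
        # Shared defaults merged into every sub-head config.
        common_cfg=dict(
            type='LinearClsHead',
            in_channels=512,
            loss=dict(type='CrossEntropyLoss'))))
```

Note that `forward_train` expects `gt_label` to be a dict keyed by task name, so the dataset side has to supply labels in that format; that part is outside the scope of this diff.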
54 changes: 52 additions & 2 deletions tests/test_models/test_heads.py
@@ -6,8 +6,8 @@

from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
MultiLabelLinearClsHead, MultiTaskClsHead,
StackedLinearClsHead, VisionTransformerClsHead)


@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
@@ -317,3 +317,53 @@ def test_deit_head():
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)


@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
def test_multitask_head(feat):
head = MultiTaskClsHead(
sub_heads={
'task1': dict(num_classes=10),
'task2': dict(num_classes=8)
},
common_cfg=dict(type='LinearClsHead', in_channels=3))
gt_label = {
'task1': torch.randint(0, 10, (4, )),
'task2': torch.randint(0, 8, (4, )),
}

losses = head.forward_train(feat, gt_label)
assert losses['task1_loss'].item() > 0
assert losses['task2_loss'].item() > 0

# test simple_test with post_process
pred = head.simple_test(feat, post_process=True)
assert isinstance(pred, list) and len(pred) == 4
assert isinstance(pred[0], dict) and pred[0].keys() == {'task1', 'task2'}
assert len(pred[0]['task1']) == 10
assert len(pred[0]['task2']) == 8

with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred[0]['task1'].shape == (10, )

# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, dict) and pred.keys() == {'task1', 'task2'}
assert pred['task1'].shape == (4, 10)
assert pred['task2'].shape == (4, 8)

# test task_wise_args
pred_ = head.simple_test(
feat,
post_process=False,
task_wise_args=dict(task1={'softmax': False}))
    assert isinstance(pred_, dict) and pred_.keys() == {'task1', 'task2'}
torch.testing.assert_allclose(pred_['task1'].softmax(dim=1), pred['task1'])
torch.testing.assert_allclose(pred_['task2'], pred['task2'])

# test pre_logits
pre_logits = head.pre_logits(feat)
assert isinstance(pre_logits, dict)
assert pre_logits['task1'].shape == (4, 3)
assert pre_logits['task2'].shape == (4, 3)