Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tools] Support respliting data_batch with tag #7641

Merged
merged 6 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .logger import get_caller_name, get_root_logger, log_img_scale
from .misc import find_latest_checkpoint, update_data_root
from .setup_env import setup_multi_processes
from .split_batch import split_batch

__all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'update_data_root', 'setup_multi_processes', 'get_caller_name',
'log_img_scale', 'compat_cfg'
'log_img_scale', 'compat_cfg', 'split_batch'
]
45 changes: 45 additions & 0 deletions mmdet/utils/split_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def split_batch(img, img_metas, kwargs):
"""Split data_batch by tags.

Code is modified from
<https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/structure_utils.py> # noqa: E501

Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
kwargs (dict): Specific to concrete implementation.

Returns:
data_groups (dict): a dict that data_batch splited by tags,
such as 'sup', 'unsup_teacher', and 'unsup_student'.
"""

# only stack img in the batch
def fuse_list(obj_list, obj):
return torch.stack(obj_list) if isinstance(obj,
torch.Tensor) else obj_list

# select data with tag from data_batch
def select_group(data_batch, current_tag):
group_flag = [tag == current_tag for tag in data_batch['tag']]
return {
k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
for k, v in data_batch.items()
}

kwargs.update({'img': img, 'img_metas': img_metas})
kwargs.update({'tag': [meta['tag'] for meta in img_metas]})
tags = list(set(kwargs['tag']))
data_groups = {tag: select_group(kwargs, tag) for tag in tags}
for tag, group in data_groups.items():
group.pop('tag')
return data_groups
2 changes: 1 addition & 1 deletion requirements/docs.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
docutils==0.16.0
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
myst-parser
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==4.0.2
sphinx-copybutton
sphinx_markdown_tables
Expand Down
95 changes: 95 additions & 0 deletions tests/test_utils/test_split_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy

import mmcv
import numpy as np
import torch

from mmdet.utils import split_batch


def test_split_batch():
img_root = osp.join(osp.dirname(__file__), '../data/color.jpg')
img = mmcv.imread(img_root, 'color')
h, w, _ = img.shape
gt_bboxes = np.array([[0.2 * w, 0.2 * h, 0.4 * w, 0.4 * h],
[0.6 * w, 0.6 * h, 0.8 * w, 0.8 * h]],
dtype=np.float32)
gt_lables = np.ones(gt_bboxes.shape[0], dtype=np.int64)

img = torch.tensor(img).permute(2, 0, 1)
meta = dict()
meta['filename'] = img_root
meta['ori_shape'] = img.shape
meta['img_shape'] = img.shape
meta['img_norm_cfg'] = {
'mean': np.array([103.53, 116.28, 123.675], dtype=np.float32),
'std': np.array([1., 1., 1.], dtype=np.float32),
'to_rgb': False
}
meta['pad_shape'] = img.shape
# For example, tag include sup, unsup_teacher and unsup_student,
# in order to distinguish the difference between the three groups of data,
# the scale_factor of sup is [0.5, 0.5, 0.5, 0.5]
# the scale_factor of unsup_teacher is [1.0, 1.0, 1.0, 1.0]
# the scale_factor of unsup_student is [2.0, 2.0, 2.0, 2.0]
imgs = img.unsqueeze(0).repeat(9, 1, 1, 1)
img_metas = []
tags = [
'sup', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
'unsup_student', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
'unsup_student'
]
for tag in tags:
img_meta = deepcopy(meta)
if tag == 'sup':
img_meta['scale_factor'] = [0.5, 0.5, 0.5, 0.5]
img_meta['tag'] = 'sup'
elif tag == 'unsup_teacher':
img_meta['scale_factor'] = [1.0, 1.0, 1.0, 1.0]
img_meta['tag'] = 'unsup_teacher'
elif tag == 'unsup_student':
img_meta['scale_factor'] = [2.0, 2.0, 2.0, 2.0]
img_meta['tag'] = 'unsup_student'
else:
continue
img_metas.append(img_meta)
kwargs = dict()
kwargs['gt_bboxes'] = [torch.tensor(gt_bboxes)] + [torch.zeros(0, 4)] * 8
kwargs['gt_lables'] = [torch.tensor(gt_lables)] + [torch.zeros(0, )] * 8
data_groups = split_batch(imgs, img_metas, kwargs)
assert set(data_groups.keys()) == set(tags)
assert data_groups['sup']['img'].shape == (1, 3, h, w)
assert data_groups['unsup_teacher']['img'].shape == (4, 3, h, w)
assert data_groups['unsup_student']['img'].shape == (4, 3, h, w)
# the scale_factor of sup is [0.5, 0.5, 0.5, 0.5]
assert data_groups['sup']['img_metas'][0]['scale_factor'] == [
0.5, 0.5, 0.5, 0.5
]
# the scale_factor of unsup_teacher is [1.0, 1.0, 1.0, 1.0]
assert data_groups['unsup_teacher']['img_metas'][0]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][1]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][2]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][3]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
# the scale_factor of unsup_student is [2.0, 2.0, 2.0, 2.0]
assert data_groups['unsup_student']['img_metas'][0]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][1]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][2]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][3]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]