Skip to content

Commit

Permalink
[Feature] Support resizemix. (#676)
Browse files Browse the repository at this point in the history
* add resizemix

* skip torch.__version__ < 1.7.0

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* resize -> F.interpolate

* fix docs

* fix test

* add Copyright

* add argument interpolation

Co-authored-by: Ma Zerun <mzr1996@163.com>
  • Loading branch information
okotaku and mzr1996 committed Mar 7, 2022
1 parent 2037260 commit c1534f9
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmcls/models/utils/augment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@
from .cutmix import BatchCutMixLayer
from .identity import Identity
from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer

__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer')
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
'BatchResizeMixLayer')
89 changes: 89 additions & 0 deletions mmcls/models/utils/augment/resizemix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F

from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding


@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
r"""ResizeMix Random Paste layer for batch ResizeMix.
The ResizeMix will resize an image to a small patch and paste it on another
image. More details can be found in `ResizeMix: Mixing Data with Preserved
Object Information and True Labels <https://arxiv.org/abs/2012.11101>`_
Args:
alpha (float): Parameters for Beta distribution. Positive(>0)
num_classes (int): The number of classes.
lam_min(float): The minimum value of lam. Defaults to 0.1.
lam_max(float): The maximum value of lam. Defaults to 0.8.
interpolation (str): algorithm used for upsampling:
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'.
Default to 'bilinear'.
prob (float): mix probability. It should be in range [0, 1].
Default to 1.0.
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
(as percent of image size). When cutmix_minmax is not None, we
generate cutmix bounding-box using cutmix_minmax instead of alpha
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Default to True
**kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`.
Note:
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
to the range [``lam_min``, ``lam_max``].
.. math::
\lambda = \frac{Beta(\alpha, \alpha)}
{\lambda_{max} - \lambda_{min}} + \lambda_{min}
And the resize ratio of source images is calculated by :math:`\lambda`:
.. math::
\text{ratio} = \sqrt{1-lam}
"""

def __init__(self,
alpha,
num_classes,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
prob=1.0,
cutmix_minmax=None,
correct_lam=True,
**kwargs):
super(BatchResizeMixLayer, self).__init__(
alpha=alpha,
num_classes=num_classes,
prob=prob,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam,
**kwargs)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation

def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)

lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
batch_size = img.size(0)
index = torch.randperm(batch_size)

(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)

img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
img[index],
size=(bby2 - bby1, bbx2 - bbx1),
mode=self.interpolation)
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label
9 changes: 9 additions & 0 deletions tests/test_models/test_utils/test_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
]


Expand All @@ -29,6 +30,14 @@ def test_augments():
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))

# Test resizemix
augments_cfg = dict(
type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))

# Test cutmixup
augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
Expand Down

0 comments on commit c1534f9

Please sign in to comment.