Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Mixup and Cutmix for Recognizers. #681

Merged
merged 27 commits into from Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
55be548
add todo list
irvingzhang0512 Mar 5, 2021
bd0fd7e
codes of mixup/cutmix/register/recognizers
irvingzhang0512 Mar 5, 2021
a71aa6e
add unittest
irvingzhang0512 Mar 5, 2021
cab1eda
add demo config
irvingzhang0512 Mar 5, 2021
5952d58
fix unittest
irvingzhang0512 Mar 5, 2021
54f651e
remove toto list
irvingzhang0512 Mar 5, 2021
1b5bdf9
update changelog
irvingzhang0512 Mar 5, 2021
ff68dd9
Merge branch 'master' into mixup
irvingzhang0512 Mar 5, 2021
7a38fea
fix unittest and training bug
irvingzhang0512 Mar 5, 2021
bf13717
fix
irvingzhang0512 Mar 5, 2021
86c06b4
fix unittest
irvingzhang0512 Mar 5, 2021
0618906
add todo
irvingzhang0512 Mar 5, 2021
3a7bacf
remove useless codes
irvingzhang0512 Mar 6, 2021
1b2cac9
update comments
irvingzhang0512 Mar 6, 2021
eea8ee4
update docs
irvingzhang0512 Mar 6, 2021
cd3e437
fix
irvingzhang0512 Mar 7, 2021
1f887ca
fix a bug
irvingzhang0512 Mar 12, 2021
f31eaf5
update configs
irvingzhang0512 Mar 16, 2021
8e3e48c
update sthv1 training results
irvingzhang0512 Mar 16, 2021
0c8af1f
Merge branch 'master' into mixup
irvingzhang0512 Mar 16, 2021
95ce916
add tsn config and modify default alpha value
irvingzhang0512 Mar 16, 2021
0735ef7
fix lint
irvingzhang0512 Mar 16, 2021
60a6b15
add unittest
irvingzhang0512 Mar 16, 2021
5bf768b
fix tin sthv2 config
irvingzhang0512 Mar 18, 2021
e79c5da
update links
kennymckormick Mar 18, 2021
60915ea
remove useless docs
irvingzhang0512 Mar 18, 2021
749de3e
Merge branch 'master' into mixup
irvingzhang0512 Mar 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,110 @@
# Base configs this file builds on: the `sgd_100e` schedule and the default
# runtime settings. Everything below adds to or overrides those files.
_base_ = [
'../../_base_/schedules/sgd_100e.py', '../../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNet',
        pretrained='torchvision://resnet50',
        depth=50,
        norm_eval=False),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    # model training and testing settings
    # `blending` selects the mini-batch blending built by the recognizer for
    # training. To train with cutmix instead of mixup, change the type to
    # 'CutmixBlending' (it takes the same `num_classes` / `alpha` arguments).
    train_cfg=dict(
        blending=dict(type='MixupBlending', num_classes=400, alpha=.2)),
    test_cfg=dict(average_clips=None))

# dataset settings
# Dataset class name plus the Kinetics-400 video directories and annotation
# lists. Note that `ann_file_val` and `ann_file_test` point to the same list.
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
# Per-channel mean/std, expanded into the `Normalize` step of every pipeline
# below via `**img_norm_cfg`.
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
# Training pipeline: Decord-based decoding of 8 sampled clips of length 1,
# multi-scale crop + resize to 224x224, flip with ratio 0.5, normalization,
# then NCHW formatting and conversion of `imgs` / `label` to tensors.
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
# Validation pipeline: same 8 x 1 sampling as training but with
# `test_mode=True`, a resize to (-1, 256) followed by a 224 center crop, and
# flipping disabled (`flip_ratio=0`). Only `imgs` is passed through `ToTensor`.
val_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
# Test pipeline: differs from validation by sampling 25 clips instead of 8
# and by using `ThreeCrop` with crop size 256 instead of a 224 center crop.
test_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
# Loader settings and the three dataset splits. Each split pairs an
# annotation list with a video directory and one of the pipelines above; the
# test split reuses the validation video directory (`data_root_val`).
data = dict(
videos_per_gpu=32,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# Evaluation settings: interval of 5 with top-k and mean-class accuracy.
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# runtime settings
# Output directory for this experiment.
work_dir = './work_dirs/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb/'
1 change: 1 addition & 0 deletions docs/changelog.md
Expand Up @@ -7,6 +7,7 @@
**New Features**

- Support LFB [#553](https://github.com/open-mmlab/mmaction2/pull/553)
- Support Mixup and Cutmix for recognizers [#681](https://github.com/open-mmlab/mmaction2/pull/681)

**Improvements**

Expand Down
5 changes: 4 additions & 1 deletion mmaction/datasets/__init__.py
Expand Up @@ -4,6 +4,8 @@
from .audio_visual_dataset import AudioVisualDataset
from .ava_dataset import AVADataset
from .base import BaseDataset
from .blending_utils import (BaseMiniBatchBlending, CutmixBlending,
MixupBlending)
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .hvu_dataset import HVUDataset
Expand All @@ -17,5 +19,6 @@
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
'HVUDataset', 'AudioDataset', 'AudioFeatureDataset', 'ImageDataset',
'RawVideoDataset', 'AVADataset', 'AudioVisualDataset'
'RawVideoDataset', 'AVADataset', 'AudioVisualDataset',
'BaseMiniBatchBlending', 'CutmixBlending', 'MixupBlending'
]
142 changes: 142 additions & 0 deletions mmaction/datasets/blending_utils.py
@@ -0,0 +1,142 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.nn.functional as F
from torch.distributions.beta import Beta

from .registry import BLENDINGS

__all__ = ['BaseMiniBatchBlending', 'MixupBlending', 'CutmixBlending']


class BaseMiniBatchBlending(metaclass=ABCMeta):
    """Abstract interface for blending the samples of one mini-batch.

    Calling an instance converts the hard labels to one-hot labels and then
    hands the images and one-hot labels to :meth:`do_blending`, which every
    concrete blending must implement.

    Args:
        num_classes (int): The number of classes, i.e. the width of the
            one-hot encoding.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes

    @abstractmethod
    def do_blending(self, imgs, label, **kwargs):
        """Blend ``imgs`` and the one-hot ``label``; defined by subclasses."""

    def __call__(self, imgs, label, **kwargs):
        """Blend the data of a mini-batch.

        Images are float tensors with the shape of (B, N, C, H, W) for 2D
        recognizers or (B, N, C, T, H, W) for 3D recognizers.

        Labels are turned from hard labels into soft labels. Hard labels are
        integer tensors, e.g. with the shape of (B, 1), whose elements lie in
        the range [0, num_classes - 1]. The one-hot encoding appends a
        trailing ``num_classes`` dimension, so (B, 1) becomes
        (B, 1, num_classes). After blending, the soft labels form a
        probability distribution over classes with elements in [0, 1].

        Args:
            imgs (torch.Tensor): Model input images, float tensor with the
                shape of (B, N, C, H, W) or (B, N, C, T, H, W).
            label (torch.Tensor): Hard labels, integer tensor with all
                elements in the range [0, num_classes).
            kwargs (dict, optional): Extra keyword arguments forwarded
                unchanged to :meth:`do_blending`.

        Returns:
            mixed_imgs (torch.Tensor): Blended images, as returned by
                :meth:`do_blending`.
            mixed_label (torch.Tensor): Blended soft labels, as returned by
                :meth:`do_blending`.
        """
        soft_target = F.one_hot(label, num_classes=self.num_classes)

        blended = self.do_blending(imgs, soft_target, **kwargs)
        mixed_imgs, mixed_label = blended

        return mixed_imgs, mixed_label


@BLENDINGS.register_module()
class MixupBlending(BaseMiniBatchBlending):
    """Mixup applied across the samples of a mini-batch.

    The method comes from `mixup: Beyond Empirical Risk Minimization
    <https://arxiv.org/abs/1710.09412>`_.
    Code Reference https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/utils/mixup.py # noqa

    Args:
        num_classes (int): The number of classes.
        alpha (float): Value used for both parameters of the Beta
            distribution from which the mixing ratio is sampled.
            Default: 1.
    """

    def __init__(self, num_classes, alpha=1.):
        super().__init__(num_classes=num_classes)
        self.beta = Beta(alpha, alpha)

    def do_blending(self, imgs, label, **kwargs):
        """Mix every sample with a randomly chosen partner from the batch.

        Args:
            imgs (torch.Tensor): Batch of images; dim 0 is the batch.
            label (torch.Tensor): One-hot labels; dim 0 is the batch.

        Returns:
            tuple: ``(mixed_imgs, mixed_label)``, both the same convex
            combination of each sample and its partner.
        """
        assert not kwargs, f'unexpected kwargs for mixup {kwargs}'

        # A single ratio is shared by the whole batch; it is sampled before
        # the permutation so the random-number stream is consumed in a fixed
        # order.
        ratio = self.beta.sample()
        perm = torch.randperm(imgs.size(0))

        partner_imgs = imgs[perm, :]
        partner_label = label[perm, :]

        mixed_imgs = ratio * imgs + (1 - ratio) * partner_imgs
        mixed_label = ratio * label + (1 - ratio) * partner_label

        return mixed_imgs, mixed_label


@BLENDINGS.register_module()
class CutmixBlending(BaseMiniBatchBlending):
    """Cutmix applied across the samples of a mini-batch.

    The method comes from `CutMix: Regularization Strategy to Train Strong
    Classifiers with Localizable Features <https://arxiv.org/abs/1905.04899>`_.
    Code Reference https://github.com/clovaai/CutMix-PyTorch

    Args:
        num_classes (int): The number of classes.
        alpha (float): Value used for both parameters of the Beta
            distribution from which the initial mixing ratio is sampled.
            Default: 1.
    """

    def __init__(self, num_classes, alpha=1.):
        super().__init__(num_classes=num_classes)
        self.beta = Beta(alpha, alpha)

    @staticmethod
    def rand_bbox(img_size, lam):
        """Generate a random bounding box.

        Args:
            img_size (Sequence[int]): Image tensor size; the last two entries
                are read as height and width.
            lam (torch.Tensor): Mixing ratio; the box side lengths are scaled
                by ``sqrt(1 - lam)``.

        Returns:
            tuple[torch.Tensor]: ``(bbx1, bby1, bbx2, bby2)`` corner
            coordinates, clamped to the image extent.
        """
        height, width = img_size[-2], img_size[-1]
        side_ratio = torch.sqrt(1. - lam)
        half_w = torch.tensor(int(width * side_ratio)) // 2
        half_h = torch.tensor(int(height * side_ratio)) // 2

        # uniform box centre (x is drawn before y)
        center_x = torch.randint(width, (1, ))[0]
        center_y = torch.randint(height, (1, ))[0]

        bbx1 = torch.clamp(center_x - half_w, 0, width)
        bby1 = torch.clamp(center_y - half_h, 0, height)
        bbx2 = torch.clamp(center_x + half_w, 0, width)
        bby2 = torch.clamp(center_y + half_h, 0, height)

        return bbx1, bby1, bbx2, bby2

    def do_blending(self, imgs, label, **kwargs):
        """Paste a random patch from a partner sample into every sample.

        Note:
            ``imgs`` is modified in place and the same tensor is returned.

        Args:
            imgs (torch.Tensor): Batch of images; dim 0 is the batch and the
                last two dims are height and width.
            label (torch.Tensor): One-hot labels; dim 0 is the batch.

        Returns:
            tuple: ``(imgs, label)`` with the patched images and the labels
            mixed according to the area that was actually replaced.
        """
        assert not kwargs, f'unexpected kwargs for cutmix {kwargs}'

        # Permutation first, then the ratio, then the box: keeps the
        # random-number stream consumed in a fixed order.
        perm = torch.randperm(imgs.size(0))
        lam = self.beta.sample()

        bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam)
        rows = slice(bby1, bby2)
        cols = slice(bbx1, bbx2)
        imgs[:, ..., rows, cols] = imgs[perm, ..., rows, cols]

        # Clamping may have shrunk the box, so recompute the ratio from the
        # area that was really pasted.
        patch_area = 1.0 * (bbx2 - bbx1) * (bby2 - bby1)
        total_area = imgs.size()[-1] * imgs.size()[-2]
        lam = 1 - (patch_area / total_area)

        label = lam * label + (1 - lam) * label[perm, :]

        return imgs, label
1 change: 1 addition & 0 deletions mmaction/datasets/registry.py
Expand Up @@ -2,3 +2,4 @@

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
BLENDINGS = Registry('blending')
5 changes: 3 additions & 2 deletions mmaction/models/heads/base.py
Expand Up @@ -80,15 +80,16 @@ def loss(self, cls_score, labels, **kwargs):
if labels.shape == torch.Size([]):
labels = labels.unsqueeze(0)

if not self.multi_class:
if not self.multi_class and cls_score.size() != labels.size():
# TODO: Whether to show top1/top5 accuracy when using mixup/cutmix
top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(),
labels.detach().cpu().numpy(), (1, 5))
losses['top1_acc'] = torch.tensor(
top_k_acc[0], device=cls_score.device)
losses['top5_acc'] = torch.tensor(
top_k_acc[1], device=cls_score.device)

elif self.label_smooth_eps != 0:
elif self.multi_class and self.label_smooth_eps != 0:
labels = ((1 - self.label_smooth_eps) * labels +
self.label_smooth_eps / self.num_classes)

Expand Down
9 changes: 9 additions & 0 deletions mmaction/models/recognizers/base.py
Expand Up @@ -53,6 +53,13 @@ def __init__(self,
self.max_testing_views = test_cfg['max_testing_views']
assert isinstance(self.max_testing_views, int)

# mini-batch blending, e.g. mixup, cutmix, etc.
self.blending = None
if train_cfg is not None and 'blending' in train_cfg:
from mmcv.utils import build_from_cfg
from ...datasets.registry import BLENDINGS
self.blending = build_from_cfg(train_cfg['blending'], BLENDINGS)

self.init_weights()

self.fp16_enabled = False
Expand Down Expand Up @@ -171,6 +178,8 @@ def forward(self, imgs, label=None, return_loss=True, **kwargs):
if return_loss:
if label is None:
raise ValueError('Label should not be None.')
if self.blending is not None:
imgs, label = self.blending(imgs, label)
return self.forward_train(imgs, label, **kwargs)

return self.forward_test(imgs, **kwargs)
Expand Down
41 changes: 41 additions & 0 deletions tests/test_data/test_blending.py
@@ -0,0 +1,41 @@
import torch

from mmaction.datasets import CutmixBlending, MixupBlending


def test_mixup():
    """Mixup keeps the image shape and yields (B, num_classes) labels."""
    num_classes = 10
    label = torch.randint(0, num_classes, (4, ))
    blending = MixupBlending(num_classes, 0.2)
    expected_label_shape = torch.Size((4, num_classes))

    # 5-D input first, then 6-D input
    for shape in ((4, 4, 3, 32, 32), (4, 4, 2, 3, 32, 32)):
        imgs = torch.randn(*shape)
        mixed_imgs, mixed_label = blending(imgs, label)
        assert mixed_imgs.shape == torch.Size(shape)
        assert mixed_label.shape == expected_label_shape


def test_cutmix():
    """Cutmix keeps the image shape and yields (B, num_classes) labels."""
    num_classes = 10
    label = torch.randint(0, num_classes, (4, ))
    blending = CutmixBlending(num_classes, 0.2)
    expected_label_shape = torch.Size((4, num_classes))

    # 5-D input first, then 6-D input
    for shape in ((4, 4, 3, 32, 32), (4, 4, 2, 3, 32, 32)):
        imgs = torch.randn(*shape)
        mixed_imgs, mixed_label = blending(imgs, label)
        assert mixed_imgs.shape == torch.Size(shape)
        assert mixed_label.shape == expected_label_shape
12 changes: 12 additions & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
Expand Up @@ -30,6 +30,18 @@ def test_tsn():
for one_img in img_list:
recognizer(one_img, gradcam=True)

# test mixup forward
config = get_recognizer_cfg(
'tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py')
config.model['backbone']['pretrained'] = None
recognizer = build_recognizer(config.model)
input_shape = (2, 8, 3, 32, 32)
demo_inputs = generate_recognizer_demo_inputs(input_shape)
imgs = demo_inputs['imgs']
gt_labels = demo_inputs['gt_labels']
losses = recognizer(imgs, gt_labels)
assert isinstance(losses, dict)


def test_tsm():
config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
Expand Down