From fbe62f6984bbeb71625c0c345f412830484b2443 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Wed, 24 Mar 2021 10:50:53 +0800
Subject: [PATCH] [Feature] Support Mixup and Cutmix for Recognizers. (#681)

* add todo list
* codes of mixup/cutmix/register/recognizers
* add unittest
* add demo config
* fix unittest
* remove todo list
* update changelog
* fix unittest and training bug
* fix
* fix unittest
* add todo
* remove useless codes
* update comments
* update docs
* fix
* fix a bug
* update configs
* update sthv1 training results
* add tsn config and modify default alpha value
* fix lint
* add unittest
* fix tin sthv2 config
* update links
* remove useless docs

Co-authored-by: Kenny
---
 .../tin/tin_r50_1x1x8_40e_sthv2_rgb.py        |   3 -
 configs/recognition/tsm/README.md             |   9 ++
 .../tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py | 114 ++++++++++++++
 .../tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py  | 114 ++++++++++++++
 ..._video_mixup_1x1x8_100e_kinetics400_rgb.py | 110 ++++++++++++++
 docs/changelog.md                             |   1 +
 mmaction/datasets/__init__.py                 |   5 +-
 mmaction/datasets/blending_utils.py           | 142 ++++++++++++++++++
 mmaction/datasets/registry.py                 |   1 +
 mmaction/models/heads/base.py                 |  10 +-
 mmaction/models/recognizers/base.py           |   9 ++
 tests/test_data/test_blending.py              |  41 +++++
 .../test_common_modules/test_base_head.py     |  39 +++++
 .../test_recognizers/test_recognizer2d.py     |  12 ++
 14 files changed, 604 insertions(+), 6 deletions(-)
 create mode 100644 configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py
 create mode 100644 configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py
 create mode 100644 configs/recognition/tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py
 create mode 100644 mmaction/datasets/blending_utils.py
 create mode 100644 tests/test_data/test_blending.py

diff --git a/configs/recognition/tin/tin_r50_1x1x8_40e_sthv2_rgb.py b/configs/recognition/tin/tin_r50_1x1x8_40e_sthv2_rgb.py
index f2320e2306..8d1a93d561 100644
--- a/configs/recognition/tin/tin_r50_1x1x8_40e_sthv2_rgb.py
+++ b/configs/recognition/tin/tin_r50_1x1x8_40e_sthv2_rgb.py
@@ -65,19 +65,16 @@
         type=dataset_type,
         ann_file=ann_file_train,
         data_prefix=data_root,
-        filename_tmpl='{:05}.jpg',
         pipeline=train_pipeline),
     val=dict(
         type=dataset_type,
         ann_file=ann_file_val,
         data_prefix=data_root_val,
-        filename_tmpl='{:05}.jpg',
         pipeline=val_pipeline),
     test=dict(
         type=dataset_type,
         ann_file=ann_file_test,
         data_prefix=data_root_val,
-        filename_tmpl='{:05}.jpg',
         pipeline=test_pipeline))
 evaluation = dict(
     interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])
diff --git a/configs/recognition/tsm/README.md b/configs/recognition/tsm/README.md
index 219c6c89d4..e0849ad6be 100644
--- a/configs/recognition/tsm/README.md
+++ b/configs/recognition/tsm/README.md
@@ -60,6 +60,13 @@
 |[tsm_r50_1x1x16_50e_sthv2_rgb](/configs/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb.py) |height 240|8| ResNet50| ImageNet |59.93 / 62.04|86.10 / 87.35|[58.90 / 60.98](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[85.29 / 86.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 10400| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/tsm_r50_1x1x16_50e_sthv2_rgb_20201010-16469c6f.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/20201010_224215.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/20201010_224215.log.json)|
 |[tsm_r101_1x1x8_50e_sthv2_rgb](/configs/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb.py) |height 240|8| ResNet101 | ImageNet|58.59 / 61.51|85.07 / 86.90|[58.89 / 61.36](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[85.14 / 87.00](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 9784 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/tsm_r101_1x1x8_50e_sthv2_rgb_20201010-98cdedb8.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/20201010_224100.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/20201010_224100.log.json)|
 
+### MixUp & CutMix on Something-Something V1
+
+| config | resolution | gpus | backbone | pretrain | top1 acc (efficient/accurate) | top5 acc (efficient/accurate) | delta top1 acc (efficient/accurate) | delta top5 acc (efficient/accurate) | ckpt | log | json |
+| :----------------------------------------------------------- | :--------: | :--: | :------: | :------: | :---------------------------: | :---------------------------: | :---------------------------------: | :---------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
+| [tsm_r50_mixup_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py) | height 100 | 8 | ResNet50 | ImageNet | 46.35 / 48.49 | 75.07 / 76.88 | +0.77 / +0.79 | +0.05 / +0.70 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb-9eca48e5.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.json) |
+| [tsm_r50_cutmix_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py) | height 100 | 8 | ResNet50 | ImageNet | 45.92 / 47.46 | 75.23 / 76.71 | +0.34 / -0.24 | +0.21 / +0.59 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb-34934615.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.json) |
+
 Notes:
 
 1. The **gpus** indicates the number of gpu we used to get the checkpoint. It is noteworthy that the configs we provide are used for 8 gpus as default.
@@ -94,6 +101,8 @@ test_pipeline = [
 
 For more details on data preparation, you can refer to Kinetics400, Something-Something V1 and Something-Something V2 in [Data Preparation](/docs/data_preparation.md).
 
+5. When applying Mixup and CutMix, we use the hyperparameter `alpha=0.2`.
+
 ## Train
 
 You can use the following command to train a model.
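To summarize the user-facing change before the full configs below: enabling either strategy is a single `blending` entry in the model's `train_cfg`. A minimal sketch follows, with values taken from the configs added in this PR; swap `MixupBlending` for `CutmixBlending` to change strategies.

```python
# Minimal sketch of the switch this PR adds. The `blending` entry is the
# only change relative to a plain TSM config; `num_classes` must match the
# classification head, and `alpha` parameterizes the Beta(alpha, alpha)
# distribution the mixing weight is sampled from.
model = dict(
    type='Recognizer2D',
    backbone=dict(type='ResNetTSM', pretrained='torchvision://resnet50',
                  depth=50, norm_eval=False, shift_div=8),
    cls_head=dict(type='TSMHead', num_classes=174, in_channels=2048),
    train_cfg=dict(
        blending=dict(type='MixupBlending', num_classes=174, alpha=0.2)),
    test_cfg=dict(average_clips='prob'))
```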
diff --git a/configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py b/configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py
new file mode 100644
index 0000000000..abf672adc2
--- /dev/null
+++ b/configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py
@@ -0,0 +1,114 @@
+_base_ = [
+    '../../_base_/schedules/sgd_tsm_50e.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='ResNetTSM',
+        pretrained='torchvision://resnet50',
+        depth=50,
+        norm_eval=False,
+        shift_div=8),
+    cls_head=dict(
+        type='TSMHead',
+        num_classes=174,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.5,
+        init_std=0.001,
+        is_shift=True),
+    # model training and testing settings
+    train_cfg=dict(
+        blending=dict(type='CutmixBlending', num_classes=174, alpha=.2)),
+    test_cfg=dict(average_clips='prob'))
+
+# dataset settings
+dataset_type = 'RawframeDataset'
+data_root = 'data/sthv1/rawframes'
+data_root_val = 'data/sthv1/rawframes'
+ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
+ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
+ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1,
+        num_fixed_crops=13),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        twice_sample=True,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        filename_tmpl='{:05}.jpg',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        filename_tmpl='{:05}.jpg',
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        filename_tmpl='{:05}.jpg',
+        pipeline=test_pipeline))
+evaluation = dict(
+    interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])
+
+# optimizer
+optimizer = dict(weight_decay=0.0005)
+
+# runtime settings
+work_dir = './work_dirs/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/'
diff --git a/configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py b/configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py
new file mode 100644
index 0000000000..3767fa45d9
--- /dev/null
+++ b/configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py
@@ -0,0 +1,114 @@
+_base_ = [
+    '../../_base_/schedules/sgd_tsm_50e.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='ResNetTSM',
+        pretrained='torchvision://resnet50',
+        depth=50,
+        norm_eval=False,
+        shift_div=8),
+    cls_head=dict(
+        type='TSMHead',
+        num_classes=174,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.5,
+        init_std=0.001,
+        is_shift=True),
+    # model training and testing settings
+    train_cfg=dict(
+        blending=dict(type='MixupBlending', num_classes=174, alpha=.2)),
+    test_cfg=dict(average_clips='prob'))
+
+# dataset settings
+dataset_type = 'RawframeDataset'
+data_root = 'data/sthv1/rawframes'
+data_root_val = 'data/sthv1/rawframes'
+ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
+ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
+ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1,
+        num_fixed_crops=13),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        twice_sample=True,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        filename_tmpl='{:05}.jpg',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        filename_tmpl='{:05}.jpg',
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        filename_tmpl='{:05}.jpg',
+        pipeline=test_pipeline))
+evaluation = dict(
+    interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])
+
+# optimizer
+optimizer = dict(weight_decay=0.0005)
+
+# runtime settings
+work_dir = './work_dirs/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/'
diff --git a/configs/recognition/tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py
new file mode 100644
index 0000000000..e2bf027903
--- /dev/null
+++ b/configs/recognition/tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py
@@ -0,0 +1,110 @@
+_base_ = [
+    '../../_base_/schedules/sgd_100e.py', '../../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='ResNet',
+        pretrained='torchvision://resnet50',
+        depth=50,
+        norm_eval=False),
+    cls_head=dict(
+        type='TSNHead',
+        num_classes=400,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.4,
+        init_std=0.01),
+    # model training and testing settings
+    # train_cfg=dict(
+    #     blending=dict(type='CutmixBlending', num_classes=400, alpha=.2)),
+    train_cfg=dict(
+        blending=dict(type='MixupBlending', num_classes=400, alpha=.2)),
+    test_cfg=dict(average_clips=None))
+
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='DecordInit'),
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='DecordDecode'),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(type='DecordInit'),
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(type='DecordInit'),
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=25,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=32,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        pipeline=test_pipeline))
+evaluation = dict(
+    interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
+
+# runtime settings
+work_dir = './work_dirs/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb/'
diff --git a/docs/changelog.md b/docs/changelog.md
index 4190294dad..60a69cfba8 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -8,6 +8,7 @@
 
 - Support LFB ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
 - Support using backbones from MMCls for TSN ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
+- Support Mixup and Cutmix for recognizers ([#681](https://github.com/open-mmlab/mmaction2/pull/681))
 
 **Improvements**
diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py
index 83d5a390b8..48341c9383 100644
--- a/mmaction/datasets/__init__.py
+++ b/mmaction/datasets/__init__.py
@@ -4,6 +4,8 @@
 from .audio_visual_dataset import AudioVisualDataset
 from .ava_dataset import AVADataset
 from .base import BaseDataset
+from .blending_utils import (BaseMiniBatchBlending, CutmixBlending,
+                             MixupBlending)
 from .builder import build_dataloader, build_dataset
 from .dataset_wrappers import RepeatDataset
 from .hvu_dataset import HVUDataset
@@ -17,5 +19,6 @@
     'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
     'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
     'HVUDataset', 'AudioDataset', 'AudioFeatureDataset', 'ImageDataset',
-    'RawVideoDataset', 'AVADataset', 'AudioVisualDataset'
+    'RawVideoDataset', 'AVADataset', 'AudioVisualDataset',
+    'BaseMiniBatchBlending', 'CutmixBlending', 'MixupBlending'
 ]
+ """ + one_hot_label = F.one_hot(label, num_classes=self.num_classes) + + mixed_imgs, mixed_label = self.do_blending(imgs, one_hot_label, + **kwargs) + + return mixed_imgs, mixed_label + + +@BLENDINGS.register_module() +class MixupBlending(BaseMiniBatchBlending): + """Implementing Mixup in a mini-batch. + + This module is proposed in `mixup: Beyond Empirical Risk Minimization + `_. + Code Reference https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/utils/mixup.py # noqa + + Args: + num_classes (int): The number of classes. + alpha (float): Parameters for Beta distribution. + """ + + def __init__(self, num_classes, alpha=.2): + super().__init__(num_classes=num_classes) + self.beta = Beta(alpha, alpha) + + def do_blending(self, imgs, label, **kwargs): + """Blending images with mixup.""" + assert len(kwargs) == 0, f'unexpected kwargs for mixup {kwargs}' + + lam = self.beta.sample() + batch_size = imgs.size(0) + rand_index = torch.randperm(batch_size) + + mixed_imgs = lam * imgs + (1 - lam) * imgs[rand_index, :] + mixed_label = lam * label + (1 - lam) * label[rand_index, :] + + return mixed_imgs, mixed_label + + +@BLENDINGS.register_module() +class CutmixBlending(BaseMiniBatchBlending): + """Implementing Cutmix in a mini-batch. + + This module is proposed in `CutMix: Regularization Strategy to Train Strong + Classifiers with Localizable Features `_. + Code Reference https://github.com/clovaai/CutMix-PyTorch + + Args: + num_classes (int): The number of classes. + alpha (float): Parameters for Beta distribution. + """ + + def __init__(self, num_classes, alpha=.2): + super().__init__(num_classes=num_classes) + self.beta = Beta(alpha, alpha) + + @staticmethod + def rand_bbox(img_size, lam): + """Generate a random boudning box.""" + w = img_size[-1] + h = img_size[-2] + cut_rat = torch.sqrt(1. 
- lam) + cut_w = torch.tensor(int(w * cut_rat)) + cut_h = torch.tensor(int(h * cut_rat)) + + # uniform + cx = torch.randint(w, (1, ))[0] + cy = torch.randint(h, (1, ))[0] + + bbx1 = torch.clamp(cx - cut_w // 2, 0, w) + bby1 = torch.clamp(cy - cut_h // 2, 0, h) + bbx2 = torch.clamp(cx + cut_w // 2, 0, w) + bby2 = torch.clamp(cy + cut_h // 2, 0, h) + + return bbx1, bby1, bbx2, bby2 + + def do_blending(self, imgs, label, **kwargs): + """Blending images with cutmix.""" + assert len(kwargs) == 0, f'unexpected kwargs for cutmix {kwargs}' + + batch_size = imgs.size(0) + rand_index = torch.randperm(batch_size) + lam = self.beta.sample() + + bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam) + imgs[:, ..., bby1:bby2, bbx1:bbx2] = imgs[rand_index, ..., bby1:bby2, + bbx1:bbx2] + lam = 1 - (1.0 * (bbx2 - bbx1) * (bby2 - bby1) / + (imgs.size()[-1] * imgs.size()[-2])) + + label = lam * label + (1 - lam) * label[rand_index, :] + + return imgs, label diff --git a/mmaction/datasets/registry.py b/mmaction/datasets/registry.py index 2ebc753cd3..ade60f11b3 100644 --- a/mmaction/datasets/registry.py +++ b/mmaction/datasets/registry.py @@ -2,3 +2,4 @@ DATASETS = Registry('dataset') PIPELINES = Registry('pipeline') +BLENDINGS = Registry('blending') diff --git a/mmaction/models/heads/base.py b/mmaction/models/heads/base.py index 91abacd124..7815f6838d 100644 --- a/mmaction/models/heads/base.py +++ b/mmaction/models/heads/base.py @@ -79,8 +79,14 @@ def loss(self, cls_score, labels, **kwargs): losses = dict() if labels.shape == torch.Size([]): labels = labels.unsqueeze(0) + elif labels.dim() == 1 and labels.size()[0] == self.num_classes \ + and cls_score.size()[0] == 1: + # Fix a bug when training with soft labels and batch size is 1. + # When using soft labels, `labels` and `cls_socre` share the same + # shape. + labels = labels.unsqueeze(0) - if not self.multi_class: + if not self.multi_class and cls_score.size() != labels.size(): top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(), labels.detach().cpu().numpy(), (1, 5)) losses['top1_acc'] = torch.tensor( @@ -88,7 +94,7 @@ def loss(self, cls_score, labels, **kwargs): losses['top5_acc'] = torch.tensor( top_k_acc[1], device=cls_score.device) - elif self.label_smooth_eps != 0: + elif self.multi_class and self.label_smooth_eps != 0: labels = ((1 - self.label_smooth_eps) * labels + self.label_smooth_eps / self.num_classes) diff --git a/mmaction/models/recognizers/base.py b/mmaction/models/recognizers/base.py index 6df4f41111..588935f12d 100644 --- a/mmaction/models/recognizers/base.py +++ b/mmaction/models/recognizers/base.py @@ -63,6 +63,13 @@ def __init__(self, self.max_testing_views = test_cfg['max_testing_views'] assert isinstance(self.max_testing_views, int) + # mini-batch blending, e.g. mixup, cutmix, etc. 
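Since `blending_utils.py` is the core of this PR, here is a short standalone sketch of how the two new classes behave at runtime. The shapes mirror the unit tests added later in this patch; the manual seed is only for a reproducible demo and is not part of the PR.

```python
import torch

from mmaction.datasets import CutmixBlending, MixupBlending

torch.manual_seed(0)  # only for a reproducible demo

# A fake mini-batch: 4 videos, 8 segments, 3 channels, 32x32 frames (B, N, C, H, W).
imgs = torch.randn(4, 8, 3, 32, 32)
label = torch.randint(0, 10, (4, ))  # hard labels in [0, num_classes)

mixup = MixupBlending(num_classes=10, alpha=0.2)
mixed_imgs, mixed_label = mixup(imgs, label)
print(mixed_imgs.shape)        # torch.Size([4, 8, 3, 32, 32]) - unchanged
print(mixed_label.shape)       # torch.Size([4, 10]) - hard labels became soft
print(mixed_label.sum(dim=1))  # each row still sums to 1

# CutmixBlending exposes the same interface; note it pastes patches between
# clips in place, so `imgs` itself is modified.
cutmix = CutmixBlending(num_classes=10, alpha=0.2)
mixed_imgs, mixed_label = cutmix(imgs, label)
```

This also shows why the `loss()` change above is needed: after blending, the head receives `(B, num_classes)` soft labels instead of `(B,)` hard labels, so top-k accuracy is skipped and the batch-size-1 squeeze case is handled explicitly.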
diff --git a/mmaction/models/recognizers/base.py b/mmaction/models/recognizers/base.py
index 6df4f41111..588935f12d 100644
--- a/mmaction/models/recognizers/base.py
+++ b/mmaction/models/recognizers/base.py
@@ -63,6 +63,13 @@ def __init__(self,
             self.max_testing_views = test_cfg['max_testing_views']
             assert isinstance(self.max_testing_views, int)
 
+        # mini-batch blending, e.g. mixup, cutmix, etc.
+        self.blending = None
+        if train_cfg is not None and 'blending' in train_cfg:
+            from mmcv.utils import build_from_cfg
+            from ...datasets.registry import BLENDINGS
+            self.blending = build_from_cfg(train_cfg['blending'], BLENDINGS)
+
         self.init_weights()
 
         self.fp16_enabled = False
@@ -181,6 +188,8 @@ def forward(self, imgs, label=None, return_loss=True, **kwargs):
         if return_loss:
             if label is None:
                 raise ValueError('Label should not be None.')
+            if self.blending is not None:
+                imgs, label = self.blending(imgs, label)
             return self.forward_train(imgs, label, **kwargs)
 
         return self.forward_test(imgs, **kwargs)
diff --git a/tests/test_data/test_blending.py b/tests/test_data/test_blending.py
new file mode 100644
index 0000000000..53f4e7bcfe
--- /dev/null
+++ b/tests/test_data/test_blending.py
@@ -0,0 +1,41 @@
+import torch
+
+from mmaction.datasets import CutmixBlending, MixupBlending
+
+
+def test_mixup():
+    alpha = 0.2
+    num_classes = 10
+    label = torch.randint(0, num_classes, (4, ))
+    mixup = MixupBlending(num_classes, alpha)
+
+    # NCHW imgs
+    imgs = torch.randn(4, 4, 3, 32, 32)
+    mixed_imgs, mixed_label = mixup(imgs, label)
+    assert mixed_imgs.shape == torch.Size((4, 4, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((4, num_classes))
+
+    # NCTHW imgs
+    imgs = torch.randn(4, 4, 2, 3, 32, 32)
+    mixed_imgs, mixed_label = mixup(imgs, label)
+    assert mixed_imgs.shape == torch.Size((4, 4, 2, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((4, num_classes))
+
+
+def test_cutmix():
+    alpha = 0.2
+    num_classes = 10
+    label = torch.randint(0, num_classes, (4, ))
+    cutmix = CutmixBlending(num_classes, alpha)
+
+    # NCHW imgs
+    imgs = torch.randn(4, 4, 3, 32, 32)
+    mixed_imgs, mixed_label = cutmix(imgs, label)
+    assert mixed_imgs.shape == torch.Size((4, 4, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((4, num_classes))
+
+    # NCTHW imgs
+    imgs = torch.randn(4, 4, 2, 3, 32, 32)
+    mixed_imgs, mixed_label = cutmix(imgs, label)
+    assert mixed_imgs.shape == torch.Size((4, 4, 2, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((4, num_classes))
diff --git a/tests/test_models/test_common_modules/test_base_head.py b/tests/test_models/test_common_modules/test_base_head.py
index 0eeebcd3ee..6611657468 100644
--- a/tests/test_models/test_common_modules/test_base_head.py
+++ b/tests/test_models/test_common_modules/test_base_head.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from mmcv.utils import assert_dict_has_keys
 
 from mmaction.models import BaseHead
@@ -31,3 +32,41 @@ def test_base_head():
     losses = head.loss(cls_scores, gt_labels)
     assert_dict_has_keys(losses, ['loss_cls'])
     assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'
+
+    # test soft labels with batch size > 1
+    cls_scores = torch.rand((3, 3))
+    gt_labels = torch.LongTensor([[2] * 3])
+    gt_one_hot_labels = F.one_hot(gt_labels, num_classes=3).squeeze()
+    losses = head.loss(cls_scores, gt_one_hot_labels)
+    assert 'loss_cls' in losses.keys()
+    assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'
+
+    # test soft labels with batch size = 1
+    cls_scores = torch.rand((1, 3))
+    gt_labels = torch.LongTensor([2])
+    gt_one_hot_labels = F.one_hot(gt_labels, num_classes=3).squeeze()
+    losses = head.loss(cls_scores, gt_one_hot_labels)
+    assert 'loss_cls' in losses.keys()
+    assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'
+
+    # test multi-class & label smoothing
+    head = ExampleHead(
+        3,
+        400,
+        dict(type='BCELossWithLogits'),
+        multi_class=True,
+        label_smooth_eps=0.1)
+
+    # batch size > 1
+    cls_scores = torch.rand((2, 3))
+    gt_labels = torch.LongTensor([[1, 0, 1], [0, 1, 0]]).squeeze()
+    losses = head.loss(cls_scores, gt_labels)
+    assert 'loss_cls' in losses.keys()
+    assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'
+
+    # batch size = 1
+    cls_scores = torch.rand((1, 3))
+    gt_labels = torch.LongTensor([[1, 0, 1]]).squeeze()
+    losses = head.loss(cls_scores, gt_labels)
+    assert 'loss_cls' in losses.keys()
+    assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'
diff --git a/tests/test_models/test_recognizers/test_recognizer2d.py b/tests/test_models/test_recognizers/test_recognizer2d.py
index c4828cb257..b0eba84f42 100644
--- a/tests/test_models/test_recognizers/test_recognizer2d.py
+++ b/tests/test_models/test_recognizers/test_recognizer2d.py
@@ -57,6 +57,18 @@ def test_tsn():
         for one_img in img_list:
             recognizer(one_img, None, return_loss=False)
 
+    # test mixup forward
+    config = get_recognizer_cfg(
+        'tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py')
+    config.model['backbone']['pretrained'] = None
+    recognizer = build_recognizer(config.model)
+    input_shape = (2, 8, 3, 32, 32)
+    demo_inputs = generate_recognizer_demo_inputs(input_shape)
+    imgs = demo_inputs['imgs']
+    gt_labels = demo_inputs['gt_labels']
+    losses = recognizer(imgs, gt_labels)
+    assert isinstance(losses, dict)
+
 
 def test_tsm():
     config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
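Taken together, the recognizer hook and the new `BLENDINGS` registry make blending configuration-driven end to end: the recognizer builds the blending object from `train_cfg` and applies it to every training batch just before `forward_train`. A sketch of the same construction outside the recognizer (the dummy shapes are illustrative, not from the PR):

```python
import torch
from mmcv.utils import build_from_cfg

from mmaction.datasets.registry import BLENDINGS

# The same construction recognizers/base.py performs when train_cfg
# contains a 'blending' key.
train_cfg = dict(
    blending=dict(type='CutmixBlending', num_classes=174, alpha=0.2))
blending = build_from_cfg(train_cfg['blending'], BLENDINGS)

# What BaseRecognizer.forward() now does right before forward_train():
imgs = torch.randn(2, 8, 3, 64, 64)   # (B, N, C, H, W) dummy batch
label = torch.randint(0, 174, (2, ))  # hard labels
imgs, label = blending(imgs, label)   # label is now (2, 174) and soft
```

Registering blendings the same way as datasets and pipelines keeps the strategy swappable from the config alone, which is why only `train_cfg` differs between the new TSM configs.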