[Feature] Support Mixup and Cutmix for Recognizers. (#681)
* add todo list

* codes of mixup/cutmix/register/recognizers

* add unittest

* add demo config

* fix unittest

* remove todo list

* update changelog

* fix unittest and training bug

* fix

* fix unittest

* add todo

* remove useless codes

* update comments

* update docs

* fix

* fix a bug

* update configs

* update sthv1 training results

* add tsn config and modify default alpha value

* fix lint

* add unittest

* fix tin sthv2 config

* update links

* remove useless docs

Co-authored-by: Kenny <dhd.efz@gmail.com>
irvingzhang0512 and kennymckormick committed Mar 24, 2021
1 parent 3268fee commit fbe62f6
Showing 14 changed files with 604 additions and 6 deletions.
3 changes: 0 additions & 3 deletions configs/recognition/tin/tin_r50_1x1x8_40e_sthv2_rgb.py
@@ -65,19 +65,16 @@
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
filename_tmpl='{:05}.jpg',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=test_pipeline))
evaluation = dict(
interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])
9 changes: 9 additions & 0 deletions configs/recognition/tsm/README.md
@@ -60,6 +60,13 @@
|[tsm_r50_1x1x16_50e_sthv2_rgb](/configs/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb.py) |height 240|8| ResNet50| ImageNet |59.93 / 62.04|86.10 / 87.35|[58.90 / 60.98](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[85.29 / 86.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 10400| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/tsm_r50_1x1x16_50e_sthv2_rgb_20201010-16469c6f.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/20201010_224215.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv2_rgb/20201010_224215.log.json)|
|[tsm_r101_1x1x8_50e_sthv2_rgb](/configs/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb.py) |height 240|8| ResNet101 | ImageNet|58.59 / 61.51|85.07 / 86.90|[58.89 / 61.36](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[85.14 / 87.00](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 9784 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/tsm_r101_1x1x8_50e_sthv2_rgb_20201010-98cdedb8.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/20201010_224100.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv2_rgb/20201010_224100.log.json)|

### MixUp & CutMix on Something-Something V1

| config | resolution | gpus | backbone | pretrain | top1 acc (efficient/accurate) | top5 acc (efficient/accurate) | delta top1 acc (efficient/accurate) | delta top5 acc (efficient/accurate) | ckpt | log | json |
| :----------------------------------------------------------- | :--------: | :--: | :------: | :------: | :---------------------------: | :---------------------------: | :---------------------------------: | :---------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsm_r50_mixup_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py) | height 100 | 8 | ResNet50 | ImageNet | 46.35 / 48.49 | 75.07 / 76.88 | +0.77 / +0.79 | +0.05 / +0.70 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb-9eca48e5.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.json) |
| [tsm_r50_cutmix_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py) | height 100 | 8 | ResNet50 | ImageNet | 45.92 / 47.46 | 75.23 / 76.71 | +0.34 / -0.24 | +0.21 / +0.59 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb-34934615.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.json) |

Notes:

1. The **gpus** column indicates the number of GPUs we used to obtain the checkpoint. Note that the provided configs assume 8 GPUs by default.
@@ -94,6 +101,8 @@ test_pipeline = [

For more details on data preparation, you can refer to the Kinetics400, Something-Something V1 and Something-Something V2 sections in [Data Preparation](/docs/data_preparation.md).

5. When applying Mixup and CutMix, we use the hyperparameter `alpha=0.2`; a minimal config excerpt is shown below.
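
To enable either augmentation in your own config, add a `blending` entry to the model's `train_cfg`, as the configs added in this commit do. `alpha` parameterizes the `Beta(alpha, alpha)` distribution the mixing weight is drawn from, and `num_classes` must match the classification head. A minimal excerpt:

model = dict(
    type='Recognizer2D',
    # ... backbone and cls_head as usual ...
    train_cfg=dict(
        blending=dict(type='MixupBlending', num_classes=174, alpha=0.2)),
    # or: blending=dict(type='CutmixBlending', num_classes=174, alpha=0.2)
    test_cfg=dict(average_clips='prob'))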

## Train

You can use the following command to train a model.
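
The basic invocation is the standard MMAction2 entry point:

python tools/train.py ${CONFIG_FILE} [optional arguments]

For example, to train the Mixup TSM config added in this commit with periodic validation (flags as in the usual train script):

python tools/train.py configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py \
    --work-dir work_dirs/tsm_r50_mixup_1x1x8_50e_sthv1_rgb \
    --validate
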
114 changes: 114 additions & 0 deletions configs/recognition/tsm/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb.py
@@ -0,0 +1,114 @@
_base_ = [
'../../_base_/schedules/sgd_tsm_50e.py', '../../_base_/default_runtime.py'
]

# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTSM',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=8),
cls_head=dict(
type='TSMHead',
num_classes=174,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True),
# model training and testing settings
train_cfg=dict(
blending=dict(type='CutmixBlending', num_classes=174, alpha=.2)),
test_cfg=dict(average_clips='prob'))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/sthv1/rawframes'
data_root_val = 'data/sthv1/rawframes'
ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
twice_sample=True,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
filename_tmpl='{:05}.jpg',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=test_pipeline))
evaluation = dict(
interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# optimizer
optimizer = dict(weight_decay=0.0005)

# runtime settings
work_dir = './work_dirs/tsm_r50_cutmix_1x1x8_50e_sthv1_rgb/'
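
For reference, the CutMix recipe that the `CutmixBlending` entry above selects can be sketched as follows. This is a standalone illustration, not the MMAction2 implementation; it assumes a batch whose first dim is batch size and whose last two dims are spatial, with one-hot float labels:

import torch

def cutmix(imgs, one_hot_labels, alpha=0.2):
    """Paste a random box from shuffled samples and mix labels by area."""
    batch, h, w = imgs.size(0), imgs.size(-2), imgs.size(-1)
    lam = torch.distributions.Beta(alpha, alpha).sample()
    index = torch.randperm(batch)

    # Box covering roughly (1 - lam) of the frame, clamped to the image.
    cut_ratio = torch.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = torch.randint(h, (1,)).item(), torch.randint(w, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)

    imgs[..., y1:y2, x1:x2] = imgs[index, ..., y1:y2, x1:x2]
    # Recompute lam from the area actually pasted (the box may be clipped).
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[index]
    return imgs, mixed_labels

With `alpha=0.2` the sampled weight is usually close to 0 or 1, so most blended clips stay dominated by one of the two sources.
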
114 changes: 114 additions & 0 deletions configs/recognition/tsm/tsm_r50_mixup_1x1x8_50e_sthv1_rgb.py
@@ -0,0 +1,114 @@
_base_ = [
'../../_base_/schedules/sgd_tsm_50e.py', '../../_base_/default_runtime.py'
]

# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTSM',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=8),
cls_head=dict(
type='TSMHead',
num_classes=174,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True),
# model training and testing settings
train_cfg=dict(
blending=dict(type='MixupBlending', num_classes=174, alpha=.2)),
test_cfg=dict(average_clips='prob'))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/sthv1/rawframes'
data_root_val = 'data/sthv1/rawframes'
ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
twice_sample=True,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
filename_tmpl='{:05}.jpg',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=test_pipeline))
evaluation = dict(
interval=2, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# optimizer
optimizer = dict(weight_decay=0.0005)

# runtime settings
work_dir = './work_dirs/tsm_r50_mixup_1x1x8_50e_sthv1_rgb/'
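
The Mixup counterpart is simpler: whole clips and their one-hot labels are convex-combined with the same Beta-sampled weight. Again a standalone sketch under the same assumptions as the CutMix one above, not the MMAction2 implementation:

import torch

def mixup(imgs, one_hot_labels, alpha=0.2):
    """Convex-combine each sample with a randomly paired one."""
    lam = torch.distributions.Beta(alpha, alpha).sample()
    index = torch.randperm(imgs.size(0))
    mixed_imgs = lam * imgs + (1.0 - lam) * imgs[index]
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[index]
    return mixed_imgs, mixed_labels
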
110 changes: 110 additions & 0 deletions configs/recognition/tsn/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb.py
@@ -0,0 +1,110 @@
_base_ = [
'../../_base_/schedules/sgd_100e.py', '../../_base_/default_runtime.py'
]

# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNet',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False),
cls_head=dict(
type='TSNHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01),
# model training and testing settings
# train_cfg=dict(
# blending=dict(type='CutmixBlending', num_classes=400, alpha=.2)),
train_cfg=dict(
blending=dict(type='MixupBlending', num_classes=400, alpha=.2)),
test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=32,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# runtime settings
work_dir = './work_dirs/tsn_r50_video_mixup_1x1x8_100e_kinetics400_rgb/'
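
Conceptually, the blending step sits between batch loading and the loss: integer labels are expanded to one-hot vectors of size `num_classes`, blended together with the clips, and the resulting soft targets feed the classification loss. A hedged sketch of that glue, reusing the `mixup`/`cutmix` sketches above (names are illustrative, not the MMAction2 internals):

import torch.nn.functional as F

def blend_batch(imgs, labels, num_classes, blending_fn):
    # Blended targets are soft, so labels must be one-hot before mixing.
    one_hot = F.one_hot(labels, num_classes=num_classes).float()
    return blending_fn(imgs, one_hot)
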
1 change: 1 addition & 0 deletions docs/changelog.md
@@ -8,6 +8,7 @@

- Support LFB ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Support using backbones from MMCls for TSN ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Support Mixup and Cutmix for recognizers ([#681](https://github.com/open-mmlab/mmaction2/pull/681))

**Improvements**

