diff --git a/.gitignore b/.gitignore
index aac23ab8be0..04a725b73c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,6 @@ venv.bak/
 *.pvti-journal
 /cache_engine
 /report
+
+# slurm
+*.out
diff --git a/README.md b/README.md
index 30ffebac62b..0ff65c5127c 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
 - [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
 - [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
 - [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
+- [x] [DeiT-3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit3)
 - [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
 - [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
 - [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 0b3ae76327b..4abfe1b5469 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -126,6 +126,7 @@ mim install -e .
 - [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
 - [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
 - [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
+- [x] [DeiT-3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit3)
 - [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
 - [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
 - [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
diff --git a/configs/_base_/datasets/imagenet_bs64_deit3_224.py b/configs/_base_/datasets/imagenet_bs64_deit3_224.py
new file mode 100644
index 00000000000..755a7f9ad94
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_bs64_deit3_224.py
@@ -0,0 +1,83 @@
+# dataset settings
+dataset_type = 'ImageNet'
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+bgr_mean = data_preprocessor['mean'][::-1]
+bgr_std = data_preprocessor['std'][::-1]
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies='timm_increasing',
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=bgr_mean,
+        fill_std=bgr_std),
+    dict(type='PackClsInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=224,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='PackClsInputs'),
+]
+
+train_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/train.txt',
+        data_prefix='train',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    persistent_workers=True,
+)
+
+val_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/configs/_base_/datasets/imagenet_bs64_deit3_384.py b/configs/_base_/datasets/imagenet_bs64_deit3_384.py
new file mode 100644
index 00000000000..572279ee060
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_bs64_deit3_384.py
@@ -0,0 +1,63 @@
+# dataset settings
+dataset_type = 'ImageNet'
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=384,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackClsInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=384,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='PackClsInputs'),
+]
+
+train_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/train.txt',
+        data_prefix='train',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    persistent_workers=True,
+)
+
+val_dataloader = dict(
+    batch_size=64,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/configs/_base_/models/deit3/deit3-base-p16-224.py b/configs/_base_/models/deit3/deit3-base-p16-224.py
new file mode 100644
index 00000000000..1a775fcab83
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-base-p16-224.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='b',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.2),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-base-p16-384.py b/configs/_base_/models/deit3/deit3-base-p16-384.py
new file mode 100644
index 00000000000..48639b9986f
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-base-p16-384.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='b',
+        img_size=384,
+        patch_size=16,
+        drop_path_rate=0.15),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-huge-p14-224.py b/configs/_base_/models/deit3/deit3-huge-p14-224.py
new file mode 100644
index 00000000000..e39595d2a8e
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-huge-p14-224.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='h',
+        img_size=224,
+        patch_size=14,
+        drop_path_rate=0.55),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1280,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-large-p16-224.py b/configs/_base_/models/deit3/deit3-large-p16-224.py
new file mode 100644
index 00000000000..fd60a4a7908
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-large-p16-224.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='l',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.45),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-large-p16-384.py b/configs/_base_/models/deit3/deit3-large-p16-384.py
new file mode 100644
index 00000000000..364f1d24f1a
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-large-p16-384.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='l',
+        img_size=384,
+        patch_size=16,
+        drop_path_rate=0.4),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-medium-p16-224.py b/configs/_base_/models/deit3/deit3-medium-p16-224.py
new file mode 100644
index 00000000000..4fc4e284e35
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-medium-p16-224.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='m',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.2),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=512,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-small-p16-224.py b/configs/_base_/models/deit3/deit3-small-p16-224.py
new file mode 100644
index 00000000000..638a940f497
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-small-p16-224.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='s',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.05),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=384,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/_base_/models/deit3/deit3-small-p16-384.py b/configs/_base_/models/deit3/deit3-small-p16-384.py
new file mode 100644
index 00000000000..2e17c0f775a
--- /dev/null
+++ b/configs/_base_/models/deit3/deit3-small-p16-384.py
@@ -0,0 +1,24 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='DeiT3',
+        arch='s',
+        img_size=384,
+        patch_size=16,
+        drop_path_rate=0.0),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=384,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]))
diff --git a/configs/deit3/README.md b/configs/deit3/README.md
new file mode 100644
index 00000000000..859ec0aa0ee
--- /dev/null
+++ b/configs/deit3/README.md
@@ -0,0 +1,49 @@
+# DeiT III: Revenge of the ViT
+
+> [DeiT III: Revenge of the ViT](https://arxiv.org/pdf/2204.07118.pdf)
+
+<!-- [ALGORITHM] -->
+
+## Abstract
+
+A Vision Transformer (ViT) is a simple neural architecture amenable to serve several computer vision tasks. It has limited built-in architectural priors, in contrast to more recent architectures that incorporate priors either about the input data or of specific tasks. Recent works show that ViTs benefit from self-supervised pre-training, in particular BerT-like pre-training like BeiT. In this paper, we revisit the supervised training of ViTs. Our procedure builds upon and simplifies a recipe introduced for training ResNet-50. It includes a new simple data-augmentation procedure with only 3 augmentations, closer to the practice in self-supervised learning. Our evaluations on Image classification (ImageNet-1k with and without pre-training on ImageNet-21k), transfer learning and semantic segmentation show that our procedure outperforms by a large margin previous fully supervised training recipes for ViT. It also reveals that the performance of our ViT trained with supervision is comparable to that of more recent architectures. Our results could serve as better baselines for recent self-supervised approaches demonstrated on ViT.
+
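+## How to use it?
+
+A minimal inference sketch (this is a hedged example rather than part of the original recipe: it assumes the `init_model`/`inference_model` helpers exported by `mmcls.apis` in 1.x, the demo image shipped with the repo, and the result-dict key names of the 1.x inference API; any checkpoint URL from the table below works):
+
+```python
+from mmcls.apis import inference_model, init_model
+
+# Build DeiT3-S from this PR's config and load the converted weights.
+model = init_model(
+    'configs/deit3/deit3-small-p16_64xb64_in1k.py',
+    'https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_3rdparty_in1k_20221008-0f7c70cf.pth',
+    device='cpu')
+
+# Run single-image inference; the result dict holds the prediction.
+result = inference_model(model, 'demo/demo.JPEG')
+print(result['pred_class'], result['pred_score'])
+```
+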
+
+## Results and models
+
+### ImageNet-1k
+
+| Model     | Pretrain     | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config                                           | Download                                                                                |
+| :-------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------: | :-------------------------------------------------------------------------------------: |
+| DeiT3-S\* | From scratch | 224x224    | 22.06     | 4.61     | 81.35     | 95.31     | [config](./deit3-small-p16_64xb64_in1k.py)       | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_3rdparty_in1k_20221008-0f7c70cf.pth) |
+| DeiT3-S\* | From scratch | 384x384    | 22.21     | 15.52    | 83.43     | 96.68     | [config](./deit3-small-p16_64xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_3rdparty_in1k-384px_20221008-a2c1a0c7.pth) |
+| DeiT3-S\* | ImageNet-21k | 224x224    | 22.06     | 4.61     | 83.06     | 96.77     | [config](./deit3-small-p16_64xb64_in1k.py)       | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_in21k-pre_3rdparty_in1k_20221009-dcd90827.pth) |
+| DeiT3-S\* | ImageNet-21k | 384x384    | 22.21     | 15.52    | 84.84     | 97.48     | [config](./deit3-small-p16_64xb64_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_in21k-pre_3rdparty_in1k-384px_20221009-de116dd7.pth) |
+| DeiT3-M\* | From scratch | 224x224    | 38.85     | 8.00     | 82.99     | 96.22     | [config](./deit3-medium-p16_64xb64_in1k.py)      | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-medium-p16_3rdparty_in1k_20221008-3b21284d.pth) |
+| DeiT3-M\* | ImageNet-21k | 224x224    | 38.85     | 8.00     | 84.56     | 97.19     | [config](./deit3-medium-p16_64xb64_in1k.py)      | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-medium-p16_in21k-pre_3rdparty_in1k_20221009-472f11e2.pth) |
+| DeiT3-B\* | From scratch | 224x224    | 86.59     | 17.58    | 83.80     | 96.55     | [config](./deit3-base-p16_64xb64_in1k.py)        | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_3rdparty_in1k_20221008-60b8c8bf.pth) |
+| DeiT3-B\* | From scratch | 384x384    | 86.88     | 55.54    | 85.08     | 97.25     | [config](./deit3-base-p16_64xb32_in1k-384px.py)  | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_3rdparty_in1k-384px_20221009-e19e36d4.pth) |
+| DeiT3-B\* | ImageNet-21k | 224x224    | 86.59     | 17.58    | 85.70     | 97.75     | [config](./deit3-base-p16_64xb64_in1k.py)        | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_in21k-pre_3rdparty_in1k_20221009-87983ca1.pth) |
+| DeiT3-B\* | ImageNet-21k | 384x384    | 86.88     | 55.54    | 86.73     | 98.11     | [config](./deit3-base-p16_64xb32_in1k-384px.py)  | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_in21k-pre_3rdparty_in1k-384px_20221009-5e4e37b9.pth) |
+| DeiT3-L\* | From scratch | 224x224    | 304.37    | 61.60    | 84.87     | 97.01     | [config](./deit3-large-p16_64xb64_in1k.py)       | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_3rdparty_in1k_20221009-03b427ea.pth) |
+| DeiT3-L\* | From scratch | 384x384    | 304.76    | 191.21   | 85.82     | 97.60     | [config](./deit3-large-p16_64xb16_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_3rdparty_in1k-384px_20221009-4317ce62.pth) |
+| DeiT3-L\* | ImageNet-21k | 224x224    | 304.37    | 61.60    | 86.97     | 98.24     | [config](./deit3-large-p16_64xb64_in1k.py)       | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_in21k-pre_3rdparty_in1k_20221009-d8d27084.pth) |
+| DeiT3-L\* | ImageNet-21k | 384x384    | 304.76    | 191.21   | 87.73     | 98.51     | [config](./deit3-large-p16_64xb16_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_in21k-pre_3rdparty_in1k-384px_20221009-75fea03f.pth) |
+| DeiT3-H\* | From scratch | 224x224    | 632.13    | 167.40   | 85.21     | 97.36     | [config](./deit3-huge-p14_64xb32_in1k.py)        | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-huge-p14_3rdparty_in1k_20221009-e107bcb7.pth) |
+| DeiT3-H\* | ImageNet-21k | 224x224    | 632.13    | 167.40   | 87.19     | 98.26     | [config](./deit3-huge-p14_64xb32_in1k.py)        | [model](https://download.openmmlab.com/mmclassification/v0/deit3/deit3-huge-p14_in21k-pre_3rdparty_in1k_20221009-19b8a535.pth) |
+
+*Models with \* are converted from the [official repo](https://github.com/facebookresearch/deit). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
+
+## Citation
+
+```
+@article{Touvron2022DeiTIR,
+  title={DeiT III: Revenge of the ViT},
+  author={Hugo Touvron and Matthieu Cord and Herve Jegou},
+  journal={arXiv preprint arXiv:2204.07118},
+  year={2022},
+}
+```
diff --git a/configs/deit3/deit3-base-p16_64xb32_in1k-384px.py b/configs/deit3/deit3-base-p16_64xb32_in1k-384px.py
new file mode 100644
index 00000000000..b6c8a8c411e
--- /dev/null
+++ b/configs/deit3/deit3-base-p16_64xb32_in1k-384px.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-base-p16-384.py',
+    '../_base_/datasets/imagenet_bs64_deit3_384.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=32)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (32 samples per GPU)
+auto_scale_lr = dict(base_batch_size=2048)
diff --git a/configs/deit3/deit3-base-p16_64xb64_in1k.py b/configs/deit3/deit3-base-p16_64xb64_in1k.py
new file mode 100644
index 00000000000..c69a64cdd06
--- /dev/null
+++ b/configs/deit3/deit3-base-p16_64xb64_in1k.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-base-p16-224.py',
+    '../_base_/datasets/imagenet_bs64_deit3_224.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=64)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/configs/deit3/deit3-huge-p14_64xb32_in1k.py b/configs/deit3/deit3-huge-p14_64xb32_in1k.py
new file mode 100644
index 00000000000..f8cae075b6a
--- /dev/null
+++ b/configs/deit3/deit3-huge-p14_64xb32_in1k.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-huge-p14-224.py',
+    '../_base_/datasets/imagenet_bs64_deit3_224.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=32)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (32 samples per GPU)
+auto_scale_lr = dict(base_batch_size=2048)
diff --git a/configs/deit3/deit3-large-p16_64xb16_in1k-384px.py b/configs/deit3/deit3-large-p16_64xb16_in1k-384px.py
new file mode 100644
index 00000000000..84fb0feae63
--- /dev/null
+++ b/configs/deit3/deit3-large-p16_64xb16_in1k-384px.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-large-p16-384.py',
+    '../_base_/datasets/imagenet_bs64_deit3_384.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=16)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (16 samples per GPU)
+auto_scale_lr = dict(base_batch_size=1024)
diff --git a/configs/deit3/deit3-large-p16_64xb64_in1k.py b/configs/deit3/deit3-large-p16_64xb64_in1k.py
new file mode 100644
index 00000000000..a67ac21f9ba
--- /dev/null
+++ b/configs/deit3/deit3-large-p16_64xb64_in1k.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-large-p16-224.py',
+    '../_base_/datasets/imagenet_bs64_deit3_224.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=64)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/configs/deit3/deit3-medium-p16_64xb64_in1k.py b/configs/deit3/deit3-medium-p16_64xb64_in1k.py
new file mode 100644
index 00000000000..def48e682a5
--- /dev/null
+++ b/configs/deit3/deit3-medium-p16_64xb64_in1k.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-medium-p16-224.py',
+    '../_base_/datasets/imagenet_bs64_deit3_224.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=64)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/configs/deit3/deit3-small-p16_64xb64_in1k-384px.py b/configs/deit3/deit3-small-p16_64xb64_in1k-384px.py
new file mode 100644
index 00000000000..e6b3e892c34
--- /dev/null
+++ b/configs/deit3/deit3-small-p16_64xb64_in1k-384px.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-small-p16-384.py',
+    '../_base_/datasets/imagenet_bs64_deit3_384.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=64)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/configs/deit3/deit3-small-p16_64xb64_in1k.py b/configs/deit3/deit3-small-p16_64xb64_in1k.py
new file mode 100644
index 00000000000..58b0a2f1837
--- /dev/null
+++ b/configs/deit3/deit3-small-p16_64xb64_in1k.py
@@ -0,0 +1,17 @@
+_base_ = [
+    '../_base_/models/deit3/deit3-small-p16-224.py',
+    '../_base_/datasets/imagenet_bs64_deit3_224.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# dataset setting
+train_dataloader = dict(batch_size=64)
+
+# schedule settings
+optim_wrapper = dict(optimizer=dict(lr=1e-5, weight_decay=0.1))
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/configs/deit3/metafile.yml b/configs/deit3/metafile.yml
new file mode 100644
index 00000000000..0376331473b
--- /dev/null
+++ b/configs/deit3/metafile.yml
@@ -0,0 +1,310 @@
+Collections:
+  - Name: DeiT3
+    Metadata:
+      Architecture:
+        - Attention Dropout
+        - Convolution
+        - Dense Connections
+        - Dropout
+        - GELU
+        - Layer Normalization
+        - Multi-Head Attention
+        - Scaled Dot-Product Attention
+        - Tanh Activation
+    Paper:
+      URL: https://arxiv.org/pdf/2204.07118.pdf
+      Title: 'DeiT III: Revenge of the ViT'
+    README: configs/deit3/README.md
+    Code:
+      URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc2/mmcls/models/backbones/deit3.py
+      Version: v1.0.0rc2
+
+Models:
+  - Name: deit3-small-p16_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 4607954304
+      Parameters: 22059496
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 81.35
+          Top 5 Accuracy: 95.31
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_3rdparty_in1k_20221008-0f7c70cf.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-small-p16_64xb64_in1k.py
+  - Name: deit3-small-p16_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 15517663104
+      Parameters: 22205416
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 83.43
+          Top 5 Accuracy: 96.68
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_3rdparty_in1k-384px_20221008-a2c1a0c7.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-small-p16_64xb64_in1k-384px.py
+  - Name: deit3-small-p16_in21k-pre_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 4607954304
+      Parameters: 22059496
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 83.06
+          Top 5 Accuracy: 96.77
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_in21k-pre_3rdparty_in1k_20221009-dcd90827.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-small-p16_64xb64_in1k.py
+  - Name: deit3-small-p16_in21k-pre_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 15517663104
+      Parameters: 22205416
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 84.84
+          Top 5 Accuracy: 97.48
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-small-p16_in21k-pre_3rdparty_in1k-384px_20221009-de116dd7.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-small-p16_64xb64_in1k-384px.py
+  - Name: deit3-medium-p16_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 8003064320
+      Parameters: 38849512
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 82.99
+          Top 5 Accuracy: 96.22
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-medium-p16_3rdparty_in1k_20221008-3b21284d.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-medium-p16_64xb64_in1k.py
+  - Name: deit3-medium-p16_in21k-pre_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 8003064320
+      Parameters: 38849512
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 84.56
+          Top 5 Accuracy: 97.19
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-medium-p16_in21k-pre_3rdparty_in1k_20221009-472f11e2.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-medium-p16_64xb64_in1k.py
+  - Name: deit3-base-p16_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 17581972224
+      Parameters: 86585320
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 83.80
+          Top 5 Accuracy: 96.55
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_3rdparty_in1k_20221008-60b8c8bf.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-base-p16_64xb64_in1k.py
+  - Name: deit3-base-p16_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 55538974464
+      Parameters: 86877160
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 85.08
+          Top 5 Accuracy: 97.25
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_3rdparty_in1k-384px_20221009-e19e36d4.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-base-p16_64xb32_in1k-384px.py
+  - Name: deit3-base-p16_in21k-pre_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 17581972224
+      Parameters: 86585320
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 85.70
+          Top 5 Accuracy: 97.75
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_in21k-pre_3rdparty_in1k_20221009-87983ca1.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-base-p16_64xb64_in1k.py
+  - Name: deit3-base-p16_in21k-pre_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 55538974464
+      Parameters: 86877160
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 86.73
+          Top 5 Accuracy: 98.11
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-base-p16_in21k-pre_3rdparty_in1k-384px_20221009-5e4e37b9.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-base-p16_64xb32_in1k-384px.py
+  - Name: deit3-large-p16_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 61603111936
+      Parameters: 304374760
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 84.87
+          Top 5 Accuracy: 97.01
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_3rdparty_in1k_20221009-03b427ea.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-large-p16_64xb64_in1k.py
+  - Name: deit3-large-p16_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 191210034176
+      Parameters: 304763880
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 85.82
+          Top 5 Accuracy: 97.60
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_3rdparty_in1k-384px_20221009-4317ce62.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-large-p16_64xb16_in1k-384px.py
+  - Name: deit3-large-p16_in21k-pre_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 61603111936
+      Parameters: 304374760
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 86.97
+          Top 5 Accuracy: 98.24
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_in21k-pre_3rdparty_in1k_20221009-d8d27084.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-large-p16_64xb64_in1k.py
+  - Name: deit3-large-p16_in21k-pre_3rdparty_in1k-384px
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 191210034176
+      Parameters: 304763880
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 87.73
+          Top 5 Accuracy: 98.51
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-large-p16_in21k-pre_3rdparty_in1k-384px_20221009-75fea03f.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-large-p16_64xb16_in1k-384px.py
+  - Name: deit3-huge-p14_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 167400741120
+      Parameters: 632126440
+      Training Data:
+        - ImageNet-1k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 85.21
+          Top 5 Accuracy: 97.36
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-huge-p14_3rdparty_in1k_20221009-e107bcb7.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-huge-p14_64xb32_in1k.py
+  - Name: deit3-huge-p14_in21k-pre_3rdparty_in1k
+    In Collection: DeiT3
+    Metadata:
+      FLOPs: 167400741120
+      Parameters: 632126440
+      Training Data:
+        - ImageNet-21k
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 87.19
+          Top 5 Accuracy: 98.26
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit3/deit3-huge-p14_in21k-pre_3rdparty_in1k_20221009-19b8a535.pth
+    Converted From:
+      Weights: https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth
+      Code: https://github.com/facebookresearch/deit/blob/main/models_v2.py#L171
+    Config: configs/deit3/deit3-huge-p14_64xb32_in1k.py
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index d188805b1f0..62b05fa983f 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -64,6 +64,7 @@ Backbones
    ConvMixer
    ConvNeXt
    DenseNet
+   DeiT3
    DistilledVisionTransformer
    EfficientFormer
    EfficientNet
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 8ed11e853da..e9b1a34ad80 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -5,6 +5,7 @@
 from .convnext import ConvNeXt
 from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
 from .deit import DistilledVisionTransformer
+from .deit3 import DeiT3
 from .densenet import DenseNet
 from .edgenext import EdgeNeXt
 from .efficientformer import EfficientFormer
@@ -87,4 +88,5 @@
     'EfficientFormer',
     'SwinTransformerV2',
     'MViT',
+    'DeiT3',
 ]
diff --git a/mmcls/models/backbones/deit3.py b/mmcls/models/backbones/deit3.py
new file mode 100644
index 00000000000..5361d30a5f4
--- /dev/null
+++ b/mmcls/models/backbones/deit3.py
@@ -0,0 +1,443 @@
+# Copyright (c) OpenMMLab. All rights reserved.
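+"""DeiT3 backbone.
+
+Compared with the vanilla ``VisionTransformer`` backbone, DeiT3 applies
+LayerScale to the attention and FFN branches and concatenates the class
+token only after the position embedding has been added; see the ``DeiT3``
+class docstring below.
+"""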
+from typing import Sequence
+
+import numpy as np
+import torch
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmengine.model import BaseModule, ModuleList, Sequential
+from mmengine.utils import deprecated_api_warning
+from torch import nn
+
+from mmcls.registry import MODELS
+from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
+from .vision_transformer import VisionTransformer
+
+
+class DeiT3FFN(BaseModule):
+    """FFN for DeiT3.
+
+    The differences between DeiT3FFN & FFN:
+
+    1. Use LayerScale.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Defaults to 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Defaults to 1024.
+        num_fcs (int, optional): The number of fully-connected layers in
+            FFNs. Defaults to 2.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Defaults to ``dict(type='ReLU', inplace=True)``.
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Defaults to 0.0.
+        add_identity (bool, optional): Whether to add the
+            identity connection. Defaults to True.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
+        use_layer_scale (bool): Whether to use layer_scale in
+            DeiT3FFN. Defaults to True.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Defaults to None.
+    """
+
+    @deprecated_api_warning(
+        {
+            'dropout': 'ffn_drop',
+            'add_residual': 'add_identity'
+        },
+        cls_name='FFN')
+    def __init__(self,
+                 embed_dims=256,
+                 feedforward_channels=1024,
+                 num_fcs=2,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 add_identity=True,
+                 use_layer_scale=True,
+                 init_cfg=None,
+                 **kwargs):
+        super().__init__(init_cfg)
+        assert num_fcs >= 2, 'num_fcs should be no less ' \
+            f'than 2, but got {num_fcs}.'
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.num_fcs = num_fcs
+        self.act_cfg = act_cfg
+        self.activate = build_activation_layer(act_cfg)
+
+        layers = []
+        in_channels = embed_dims
+        for _ in range(num_fcs - 1):
+            layers.append(
+                Sequential(
+                    Linear(in_channels, feedforward_channels), self.activate,
+                    nn.Dropout(ffn_drop)))
+            in_channels = feedforward_channels
+        layers.append(Linear(feedforward_channels, embed_dims))
+        layers.append(nn.Dropout(ffn_drop))
+        self.layers = Sequential(*layers)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else torch.nn.Identity()
+        self.add_identity = add_identity
+
+        if use_layer_scale:
+            self.gamma2 = LayerScale(embed_dims)
+        else:
+            self.gamma2 = nn.Identity()
+
+    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+    def forward(self, x, identity=None):
+        """Forward function for `FFN`.
+
+        The function adds ``x`` to the output tensor when ``identity``
+        is None.
+        """
+        out = self.layers(x)
+        out = self.gamma2(out)
+        if not self.add_identity:
+            return self.dropout_layer(out)
+        if identity is None:
+            identity = x
+        return identity + self.dropout_layer(out)
+
+
+class DeiT3TransformerEncoderLayer(BaseModule):
+    """Implements one encoder layer in DeiT3.
+
+    The differences between DeiT3TransformerEncoderLayer &
+    TransformerEncoderLayer:
+
+    1. Use LayerScale.
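+
+    Here, LayerScale multiplies the output of each residual branch (both
+    attention and FFN) by a learnable per-channel weight initialized to a
+    small value, which helps stabilize very deep ViTs.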
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Defaults to 0.
+        attn_drop_rate (float): The drop out rate for attention output weights.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Defaults to 2.
+        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
+        use_layer_scale (bool): Whether to use layer_scale in
+            DeiT3TransformerEncoderLayer. Defaults to True.
+        act_cfg (dict): The activation config for FFNs.
+            Defaults to ``dict(type='GELU')``.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=True,
+                 use_layer_scale=True,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, self.embed_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+        self.attn = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            qkv_bias=qkv_bias,
+            use_layer_scale=use_layer_scale)
+
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, self.embed_dims, postfix=2)
+        self.add_module(self.norm2_name, norm2)
+
+        self.ffn = DeiT3FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg,
+            use_layer_scale=use_layer_scale)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
+    def init_weights(self):
+        super(DeiT3TransformerEncoderLayer, self).init_weights()
+        for m in self.ffn.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.normal_(m.bias, std=1e-6)
+
+    def forward(self, x):
+        x = x + self.attn(self.norm1(x))
+        x = self.ffn(self.norm2(x), identity=x)
+        return x
+
+
+@MODELS.register_module()
+class DeiT3(VisionTransformer):
+    """DeiT3 backbone.
+
+    A PyTorch implementation of: `DeiT III: Revenge of the ViT
+    <https://arxiv.org/pdf/2204.07118.pdf>`_
+
+    The differences between DeiT3 & VisionTransformer:
+
+    1. Use LayerScale.
+    2. Concat cls token after adding pos_embed.
+
+    Args:
+        arch (str | dict): DeiT3 architecture. If a string is given,
+            choose from 'small', 'base', 'medium', 'large' and 'huge'.
+            If a dict is given, it should have the below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **feedforward_channels** (int): The hidden dimensions in
+              feedforward modules.
+
+            Defaults to 'base'.
+        img_size (int | tuple): The expected input image shape. Because we
+            support dynamic input shape, just set the argument to the most
+            common input image shape. Defaults to 224.
+        patch_size (int | tuple): The patch size in patch embedding.
+            Defaults to 16.
+        in_channels (int): The number of input channels. Defaults to 3.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            final feature map. Defaults to True.
+        with_cls_token (bool): Whether to concatenate the class token into
+            image tokens as transformer input. Defaults to True.
+        output_cls_token (bool): Whether to output the cls_token. If set
+            True, ``with_cls_token`` must be True. Defaults to True.
+        use_layer_scale (bool): Whether to use layer_scale in DeiT3.
+            Defaults to True.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+    arch_zoo = {
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 1536,
+            }),
+        **dict.fromkeys(
+            ['m', 'medium'], {
+                'embed_dims': 512,
+                'num_layers': 12,
+                'num_heads': 8,
+                'feedforward_channels': 2048,
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 3072
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': 4096
+            }),
+        **dict.fromkeys(
+            ['h', 'huge'], {
+                'embed_dims': 1280,
+                'num_layers': 32,
+                'num_heads': 16,
+                'feedforward_channels': 5120
+            }),
+    }
+    # num_extra_tokens is not used in DeiT3 because the cls token is
+    # concatenated after adding the pos_embed
+    num_extra_tokens = 0
+
+    def __init__(self,
+                 arch='base',
+                 img_size=224,
+                 patch_size=16,
+                 in_channels=3,
+                 out_indices=-1,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 qkv_bias=True,
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 final_norm=True,
+                 with_cls_token=True,
+                 output_cls_token=True,
+                 use_layer_scale=True,
+                 interpolate_mode='bicubic',
+                 patch_cfg=dict(),
+                 layer_cfgs=dict(),
+                 init_cfg=None):
+        super(VisionTransformer, self).__init__(init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {
+                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
+            }
+            assert isinstance(arch, dict) and essential_keys <= set(arch), \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.num_layers = self.arch_settings['num_layers']
+        self.img_size = to_2tuple(img_size)
+
+        # Set patch embedding
+        _patch_cfg = dict(
+            in_channels=in_channels,
+            input_size=img_size,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=patch_size,
+        )
+        _patch_cfg.update(patch_cfg)
+        self.patch_embed = PatchEmbed(**_patch_cfg)
+        self.patch_resolution = self.patch_embed.init_out_size
+        num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+        # Set cls token
+        if output_cls_token:
+            assert with_cls_token is True, \
+                f'with_cls_token must be True if set output_cls_token to ' \
+                f'True, but got {with_cls_token}'
+        self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+        # Set position embedding
+        self.interpolate_mode = interpolate_mode
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches, self.embed_dims))
+        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
+
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'get {type(out_indices)} instead.'
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = self.num_layers + index
+            assert 0 <= out_indices[i] <= self.num_layers, \
+                f'Invalid out_indices {index}'
+        self.out_indices = out_indices
+
+        # stochastic depth decay rule
+        dpr = np.linspace(0, drop_path_rate, self.num_layers)
+
+        self.layers = ModuleList()
+        if isinstance(layer_cfgs, dict):
+            layer_cfgs = [layer_cfgs] * self.num_layers
+        for i in range(self.num_layers):
+            _layer_cfg = dict(
+                embed_dims=self.embed_dims,
+                num_heads=self.arch_settings['num_heads'],
+                feedforward_channels=self.
+                arch_settings['feedforward_channels'],
+                drop_rate=drop_rate,
+                drop_path_rate=dpr[i],
+                qkv_bias=qkv_bias,
+                norm_cfg=norm_cfg,
+                use_layer_scale=use_layer_scale)
+            _layer_cfg.update(layer_cfgs[i])
+            self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg))
+
+        self.final_norm = final_norm
+        if final_norm:
+            self.norm1_name, norm1 = build_norm_layer(
+                norm_cfg, self.embed_dims, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+
+    def forward(self, x):
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        x = x + resize_pos_embed(
+            self.pos_embed,
+            self.patch_resolution,
+            patch_resolution,
+            mode=self.interpolate_mode,
+            num_extra_tokens=self.num_extra_tokens)
+        x = self.drop_after_pos(x)
+
+        # stole cls_tokens impl from Phil Wang, thanks
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        if not self.with_cls_token:
+            # Remove class token for transformer encoder input
+            x = x[:, 1:]
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.norm1(x)
+
+            if i in self.out_indices:
+                B, _, C = x.shape
+                if self.with_cls_token:
+                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
+                    patch_token = patch_token.permute(0, 3, 1, 2)
+                    cls_token = x[:, 0]
+                else:
+                    patch_token = x.reshape(B, *patch_resolution, C)
+                    patch_token = patch_token.permute(0, 3, 1, 2)
+                    cls_token = None
+                if self.output_cls_token:
+                    out = [patch_token, cls_token]
+                else:
+                    out = patch_token
+                outs.append(out)
+
+        return tuple(outs)
diff --git a/mmcls/models/utils/attention.py b/mmcls/models/utils/attention.py
index 064ec388211..74bd1bd5505 100644
--- a/mmcls/models/utils/attention.py
+++ b/mmcls/models/utils/attention.py
@@ -10,6 +10,7 @@
 
 from mmcls.registry import MODELS
 from .helpers import to_2tuple
+from .layer_scale import LayerScale
 
 # After pytorch v1.10.0, use torch.meshgrid without indexing
 # will raise extra warning. For more details,
@@ -511,6 +512,7 @@ def __init__(self,
                  qk_scale=None,
                  proj_bias=True,
                  v_shortcut=False,
+                 use_layer_scale=False,
                  init_cfg=None):
         super(MultiheadAttention, self).__init__(init_cfg=init_cfg)
 
@@ -529,6 +531,11 @@ def __init__(self,
 
         self.out_drop = build_dropout(dropout_layer)
 
+        if use_layer_scale:
+            self.gamma1 = LayerScale(embed_dims)
+        else:
+            self.gamma1 = nn.Identity()
+
     def forward(self, x):
         B, N, _ = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
@@ -541,7 +548,7 @@ def forward(self, x):
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
         x = self.proj(x)
-        x = self.out_drop(self.proj_drop(x))
+        x = self.out_drop(self.gamma1(self.proj_drop(x)))
 
         if self.v_shortcut:
             x = v.squeeze(1) + x
diff --git a/model-index.yml b/model-index.yml
index b1fa357f0e1..72ea0381d40 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -33,3 +33,4 @@ Import:
   - configs/mobileone/metafile.yml
   - configs/efficientformer/metafile.yml
   - configs/swin_transformer_v2/metafile.yml
+  - configs/deit3/metafile.yml
diff --git a/tests/test_models/test_backbones/test_deit3.py b/tests/test_models/test_backbones/test_deit3.py
new file mode 100644
index 00000000000..7d7d485e418
--- /dev/null
+++ b/tests/test_models/test_backbones/test_deit3.py
@@ -0,0 +1,179 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import os
+import tempfile
+from copy import deepcopy
+from unittest import TestCase
+
+import torch
+from mmengine.runner import load_checkpoint, save_checkpoint
+
+from mmcls.models.backbones import DeiT3
+
+
+class TestDeiT3(TestCase):
+
+    def setUp(self):
+        self.cfg = dict(
+            arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)
+
+    def test_structure(self):
+        # Test invalid default arch
+        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = 'unknown'
+            DeiT3(**cfg)
+
+        # Test invalid custom arch
+        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = {
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': 4096
+            }
+            DeiT3(**cfg)
+
+        # Test custom arch
+        cfg = deepcopy(self.cfg)
+        cfg['arch'] = {
+            'embed_dims': 128,
+            'num_layers': 24,
+            'num_heads': 16,
+            'feedforward_channels': 1024
+        }
+        model = DeiT3(**cfg)
+        self.assertEqual(model.embed_dims, 128)
+        self.assertEqual(model.num_layers, 24)
+        for layer in model.layers:
+            self.assertEqual(layer.attn.num_heads, 16)
+            self.assertEqual(layer.ffn.feedforward_channels, 1024)
+
+        # Test out_indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = {1: 1}
+        with self.assertRaisesRegex(AssertionError, "get "):
+            DeiT3(**cfg)
+        cfg['out_indices'] = [0, 13]
+        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
+            DeiT3(**cfg)
+
+        # Test model structure
+        cfg = deepcopy(self.cfg)
+        model = DeiT3(**cfg)
+        self.assertEqual(len(model.layers), 12)
+        dpr_inc = 0.1 / (12 - 1)
+        dpr = 0
+        for layer in model.layers:
+            self.assertEqual(layer.attn.embed_dims, 768)
+            self.assertEqual(layer.attn.num_heads, 12)
+            self.assertEqual(layer.ffn.feedforward_channels, 3072)
+            self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
+            self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
+            dpr += dpr_inc
+
+    def test_init_weights(self):
+        # test weight init cfg
+        cfg = deepcopy(self.cfg)
+        cfg['init_cfg'] = [
+            dict(
+                type='Kaiming',
+                layer='Conv2d',
+                mode='fan_in',
+                nonlinearity='linear')
+        ]
+        model = DeiT3(**cfg)
+        ori_weight = model.patch_embed.projection.weight.clone().detach()
+        # The pos_embed is all zeros before initialization
+        self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))
+
+        model.init_weights()
+        initialized_weight = model.patch_embed.projection.weight
+        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
+        self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))
+
+        # test load checkpoint
+        pretrain_pos_embed = model.pos_embed.clone().detach()
+        tmpdir = tempfile.gettempdir()
+        checkpoint = os.path.join(tmpdir, 'test.pth')
+        save_checkpoint(model.state_dict(), checkpoint)
+        cfg = deepcopy(self.cfg)
+        model = DeiT3(**cfg)
+        load_checkpoint(model, checkpoint, strict=True)
+        self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
+
+        # test load checkpoint with different img_size
+        cfg = deepcopy(self.cfg)
+        cfg['img_size'] = 384
+        model = DeiT3(**cfg)
+        load_checkpoint(model, checkpoint, strict=True)
+
+        os.remove(checkpoint)
+
+    def test_forward(self):
+        imgs = torch.randn(1, 3, 224, 224)
+
+        # test with_cls_token=False
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['output_cls_token'] = True
+        with self.assertRaisesRegex(AssertionError, 'but got False'):
+            DeiT3(**cfg)
+
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['output_cls_token'] = False
+        model = DeiT3(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+
+        # test with output_cls_token
+        cfg = deepcopy(self.cfg)
+        model = DeiT3(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token, cls_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+        self.assertEqual(cls_token.shape, (1, 768))
+
+        # test without output_cls_token
+        cfg = deepcopy(self.cfg)
+        cfg['output_cls_token'] = False
+        model = DeiT3(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+
+        # Test forward with multi out indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = [-3, -2, -1]
+        model = DeiT3(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 3)
+        for out in outs:
+            patch_token, cls_token = out
+            self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+            self.assertEqual(cls_token.shape, (1, 768))
+
+        # Test forward with dynamic input size
+        imgs1 = torch.randn(1, 3, 224, 224)
+        imgs2 = torch.randn(1, 3, 256, 256)
+        imgs3 = torch.randn(1, 3, 256, 309)
+        cfg = deepcopy(self.cfg)
+        model = DeiT3(**cfg)
+        for imgs in [imgs1, imgs2, imgs3]:
+            outs = model(imgs)
+            self.assertIsInstance(outs, tuple)
+            self.assertEqual(len(outs), 1)
+            patch_token, cls_token = outs[-1]
+            expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
+                                 math.ceil(imgs.shape[3] / 16))
+            self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
+            self.assertEqual(cls_token.shape, (1, 768))
diff --git a/tools/model_converters/deit3_to_mmcls.py b/tools/model_converters/deit3_to_mmcls.py
new file mode 100644
index 00000000000..73427870e52
--- /dev/null
+++ b/tools/model_converters/deit3_to_mmcls.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
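+"""Convert DeiT3 (DeiT III) checkpoints from the official repo
+(https://github.com/facebookresearch/deit) to the mmcls key layout; see
+``convert_deit3`` below for the exact mapping."""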
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_deit3(ckpt):
+
+    new_ckpt = OrderedDict()
+
+    for k, v in list(ckpt.items()):
+        new_v = v
+        if k.startswith('head'):
+            new_k = k.replace('head.', 'head.layers.head.')
+            new_ckpt[new_k] = new_v
+            continue
+        elif k.startswith('patch_embed'):
+            if 'proj.' in k:
+                new_k = k.replace('proj.', 'projection.')
+            else:
+                new_k = k
+        elif k.startswith('blocks'):
+            new_k = k.replace('blocks.', 'layers.')
+            if 'norm1' in k:
+                new_k = new_k.replace('norm1', 'ln1')
+            elif 'norm2' in k:
+                new_k = new_k.replace('norm2', 'ln2')
+            elif 'mlp.fc1' in k:
+                new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
+            elif 'mlp.fc2' in k:
+                new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
+            elif 'gamma_1' in k:
+                new_k = new_k.replace('gamma_1', 'attn.gamma1.weight')
+            elif 'gamma_2' in k:
+                new_k = new_k.replace('gamma_2', 'ffn.gamma2.weight')
+        elif k.startswith('norm'):
+            new_k = k.replace('norm', 'ln1')
+        else:
+            new_k = k
+
+        if not new_k.startswith('head'):
+            new_k = 'backbone.' + new_k
+        new_ckpt[new_k] = new_v
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained deit3 models to mmcls style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    if 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+
+    weight = convert_deit3(state_dict)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+    print('Done!!')
+
+
+if __name__ == '__main__':
+    main()
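
---

As a quick sanity check of the key mapping above, here is a hedged sketch (the dummy keys and shapes are illustrative, not a real DeiT III checkpoint; run it in the same module as `convert_deit3`, or import the function from the script):

```python
from collections import OrderedDict

import torch

# Toy official-style keys covering each branch of convert_deit3.
dummy = OrderedDict({
    'patch_embed.proj.weight': torch.zeros(384, 3, 16, 16),
    'blocks.0.norm1.weight': torch.zeros(384),
    'blocks.0.gamma_1': torch.zeros(384),
    'blocks.0.mlp.fc1.weight': torch.zeros(1536, 384),
    'norm.weight': torch.zeros(384),
    'head.weight': torch.zeros(1000, 384),
})
for new_key in convert_deit3(dummy):
    print(new_key)
# Expected, per the mapping above:
# backbone.patch_embed.projection.weight
# backbone.layers.0.ln1.weight
# backbone.layers.0.attn.gamma1.weight
# backbone.layers.0.ffn.layers.0.0.weight
# backbone.ln1.weight
# head.layers.head.weight
```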