[Feature] Support VAN #739

Merged on Apr 28, 2022 (11 commits)
Changes from 5 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -120,6 +120,7 @@ venv.bak/
*.log.json
/work_dirs
/mmcls/.mim
.DS_Store

# Pytorch
*.pth
1 change: 1 addition & 0 deletions README.md
@@ -126,6 +126,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)

</details>

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -124,6 +124,7 @@ MMClassification is an open-source image classification toolbox based on PyTorch, and is [O
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van)

</details>

13 changes: 13 additions & 0 deletions configs/_base_/models/van/van_base.py
@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='base', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False))
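The head above uses `LabelSmoothLoss` with `label_smooth_val=0.1`. As a rough illustration of what that setting does to a training target, here is a minimal sketch that assumes `mode='original'` means the standard formulation `(1 - eps) * one_hot + eps / num_classes`. The helper name `smooth_one_hot` is hypothetical and not part of this PR.

```python
def smooth_one_hot(label, num_classes, eps=0.1):
    """Return a smoothed target: (1 - eps) * one_hot + eps / num_classes.

    Assumption: this is the standard label-smoothing formula; the helper
    itself is illustrative only.
    """
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    uniform = eps / num_classes
    target = [uniform] * num_classes
    target[label] += 1.0 - eps
    return target


# Small example with 5 classes instead of the 1000 used in the config.
target = smooth_one_hot(label=2, num_classes=5, eps=0.1)
print([round(t, 3) for t in target])  # [0.02, 0.02, 0.92, 0.02, 0.02]
```

The smoothed target still sums to 1, so it can be used directly with a soft-label cross-entropy.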
13 changes: 13 additions & 0 deletions configs/_base_/models/van/van_large.py
@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='large', drop_path_rate=0.2),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False))
21 changes: 21 additions & 0 deletions configs/_base_/models/van/van_small.py
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='small', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
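The `train_cfg` above enables `BatchMixup` with `alpha=0.8`. The following is a minimal sketch of the general Mixup idea, not the implementation used by this PR: a mixing ratio is drawn from `Beta(alpha, alpha)` and each sample is blended with a shuffled partner from the same batch. The function name `mixup_batch` and the flat-list image representation are assumptions made for illustration.

```python
import random


def mixup_batch(images, one_hot_labels, alpha=0.8, rng=random):
    """Blend each sample with a randomly chosen partner from the same batch.

    `images` and `one_hot_labels` are lists of equal-length float lists.
    Illustrative only; a real pipeline would operate on tensors.
    """
    lam = rng.betavariate(alpha, alpha)  # mixing ratio in [0, 1]
    perm = list(range(len(images)))
    rng.shuffle(perm)

    def blend(a, b):
        return [lam * x + (1.0 - lam) * y for x, y in zip(a, b)]

    mixed_images = [blend(images[i], images[j]) for i, j in enumerate(perm)]
    mixed_labels = [
        blend(one_hot_labels[i], one_hot_labels[j])
        for i, j in enumerate(perm)
    ]
    return mixed_images, mixed_labels, lam


rng = random.Random(0)
images = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
mixed_images, mixed_labels, lam = mixup_batch(images, labels, rng=rng)
```

Because images and labels are blended with the same ratio, every mixed label row still sums to 1.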
21 changes: 21 additions & 0 deletions configs/_base_/models/van/van_tiny.py
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VAN', arch='tiny', drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=256,
        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))
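The `init_cfg` above asks for truncated-normal weights (`std=0.02`) and zero bias on `Linear` layers. As a rough picture of what that means, here is a sketch that samples by rejection. It assumes absolute truncation bounds of ±2, which are the defaults of PyTorch's `trunc_normal_`; whether the `TruncNormal` initializer here uses the same bounds is an assumption. The names `trunc_normal` and `init_linear` are hypothetical.

```python
import random


def trunc_normal(mean=0.0, std=0.02, low=-2.0, high=2.0, rng=random):
    """Draw one value from N(mean, std^2), rejecting samples outside [low, high]."""
    while True:
        value = rng.gauss(mean, std)
        if low <= value <= high:
            return value


def init_linear(out_features, in_features, std=0.02, bias=0.0, rng=random):
    """Return (weight, bias) for a linear layer as plain Python lists."""
    weight = [
        [trunc_normal(std=std, rng=rng) for _ in range(in_features)]
        for _ in range(out_features)
    ]
    return weight, [bias] * out_features


weight, bias = init_linear(out_features=3, in_features=4, rng=random.Random(0))
```

With `std=0.02`, bounds of ±2 sit 100 standard deviations from the mean, so under this assumption the truncation almost never rejects a sample.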
39 changes: 39 additions & 0 deletions configs/van/README.md
@@ -0,0 +1,39 @@
# Visual Attention Network

> [Visual Attention Network](https://arxiv.org/pdf/2202.09741v2.pdf)
<!-- [ALGORITHM] -->

## Abstract

While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/157409484-f26fcc1f-a856-48c2-a7a7-d157c38877ac.png" width="90%"/>
</div>

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="90%"/>
</div>


## Results and models

### ImageNet-1k

| Model | Pretrain | Resolution | Params (M) | FLOPs (G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:------:|:--------:|
| VAN-T | From scratch | 224x224 | 4.1 | 0.9 | 75.4 | | | |
| VAN-S | From scratch | 224x224 | 13.9 | 2.5 | 81.1 | | | |
| VAN-B | From scratch | 224x224 | 26.6 | 5.0 | 82.8 | | | |
| VAN-L | From scratch | 224x224 | 44.8 | 9.0 | 83.9 | | | |

## Citation

```
@article{guo2022visual,
title={Visual Attention Network},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2202.09741},
year={2022}
}
```
73 changes: 73 additions & 0 deletions configs/van/metafile.yml
@@ -0,0 +1,73 @@
Collections:
  - Name: Visual-Attention-Network
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
        - Weight Decay
      Training Resources:
      Epochs: 300
      Batch Size: 1024
      Architecture:
        - Visual Attention Network
    Paper:
      URL: https://arxiv.org/pdf/2202.09741v2.pdf
      Title: "Visual Attention Network"
    README: configs/van/README.md
    Code:
      URL:
      Version:

Models:
  - Name: van-tiny_8xb128_in1k
    Metadata:
      FLOPs:
      Parameters:
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 75.4
          Top 5 Accuracy:
        Task: Image Classification
    Weights:
    Config: configs/van/van-tiny_8xb128_in1k.py
  - Name: van-small_8xb128_in1k
    Metadata:
      FLOPs:
      Parameters:
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.1
          Top 5 Accuracy:
        Task: Image Classification
    Weights:
    Config: configs/van/van-small_8xb128_in1k.py
  - Name: van-base_8xb128_in1k
    Metadata:
      FLOPs:
      Parameters:
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.8
          Top 5 Accuracy:
        Task: Image Classification
    Weights:
    Config: configs/van/van-base_8xb128_in1k.py
  - Name: van-large_8xb128_in1k
    Metadata:
      FLOPs:
      Parameters:
    In Collection: Visual-Attention-Network
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.9
          Top 5 Accuracy:
        Task: Image Classification
    Weights:
    Config: configs/van/van-large_8xb128_in1k.py
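The model names in the metafile appear to encode the training setup: `8xb128` reads as 8 GPUs with 128 samples each, which matches `Batch Size: 1024` in the collection metadata. The sketch below checks that arithmetic. The naming scheme `<model>_<gpus>xb<batch per gpu>_<dataset>` is inferred from the names above, and `parse_config_name` is a hypothetical helper.

```python
import re

# Assumed naming scheme: <model>_<gpus>xb<batch per gpu>_<dataset>.
NAME_PATTERN = re.compile(
    r"^(?P<model>[a-z0-9-]+)_(?P<gpus>\d+)xb(?P<per_gpu>\d+)_(?P<dataset>[a-z0-9]+)$"
)


def parse_config_name(name):
    """Split a config name into its parts and derive the total batch size."""
    match = NAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"unexpected config name: {name!r}")
    gpus = int(match.group("gpus"))
    per_gpu = int(match.group("per_gpu"))
    return {
        "model": match.group("model"),
        "gpus": gpus,
        "samples_per_gpu": per_gpu,
        "total_batch_size": gpus * per_gpu,
        "dataset": match.group("dataset"),
    }


info = parse_config_name("van-tiny_8xb128_in1k")
print(info["total_batch_size"])  # 1024
```

The config files added in this diff use a hyphen before `8xb128` (for example `van-base-8xb128_in1k.py`), so their names would not match the `Config` paths listed in the metafile, which use an underscore.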
56 changes: 56 additions & 0 deletions configs/van/van-base-8xb128_in1k.py
@@ -0,0 +1,56 @@
_base_ = [
    '../_base_/models/van/van_base.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(samples_per_gpu=128)
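The `pad_val`, `fill_color` and `fill_std` expressions in the pipeline above all reverse the normalization statistics with `[::-1]`. A plausible reading, stated here as an assumption, is that the mean and std are given in RGB order (hence `to_rgb=True`) while the augmentations run on images that are still in BGR order, before `Normalize`. The sketch below only evaluates the expressions from the config.

```python
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# Reverse the channel order (assumed RGB -> BGR) and round to integer pixels.
bgr_mean = img_norm_cfg['mean'][::-1]
pad_val = [round(x) for x in bgr_mean]

print(bgr_mean)  # [103.53, 116.28, 123.675]
print(pad_val)   # [104, 116, 124]
```

Padding with the rounded mean keeps padded regions close to zero after normalization, which is likely why the mean is used instead of black.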
56 changes: 56 additions & 0 deletions configs/van/van-large-8xb128_in1k.py
@@ -0,0 +1,56 @@
_base_ = [
    '../_base_/models/van/van_large.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(samples_per_gpu=128)
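The test pipeline above resizes with `size=(248, -1)` and then center-crops to 224. The sketch below works through the geometry under the assumption that `-1` means "scale the short edge to 248 and keep the aspect ratio". The helper names are hypothetical, and the library's exact rounding may differ.

```python
def short_edge_resize(height, width, short_edge=248):
    """Scale so the shorter side equals `short_edge`, keeping aspect ratio."""
    scale = short_edge / min(height, width)
    return round(height * scale), round(width * scale)


def center_crop_box(height, width, crop_size=224):
    """Return (top, left, bottom, right) of a centered square crop."""
    top = max((height - crop_size) // 2, 0)
    left = max((width - crop_size) // 2, 0)
    return top, left, min(top + crop_size, height), min(left + crop_size, width)


resized = short_edge_resize(480, 640)
print(resized)                    # (248, 331)
print(center_crop_box(*resized))  # (12, 53, 236, 277)
```

The crop keeps 224 / 248, or about 90%, of the resized short edge.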
56 changes: 56 additions & 0 deletions configs/van/van-small-8xb128_in1k.py
@@ -0,0 +1,56 @@
_base_ = [
    '../_base_/models/van/van_small.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(248, -1),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

data = dict(samples_per_gpu=128)
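Each config in this diff inherits from `_base_` files and overrides only a few keys, such as `data = dict(samples_per_gpu=128)`. The sketch below models that inheritance as a recursive dictionary merge in which the child wins. It is a simplified stand-in, not the config system's actual code, and the base values (`samples_per_gpu=64`, `workers_per_gpu=8`, the one-step pipeline) are hypothetical placeholders for whatever the base dataset file defines.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`; override wins."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical stand-in for what the base dataset file might define.
base_pipeline = [dict(type='LoadImageFromFile')]
base = dict(
    train_pipeline=base_pipeline,
    data=dict(
        samples_per_gpu=64,
        workers_per_gpu=8,
        train=dict(pipeline=base_pipeline)))

# Shape of what the child configs in this diff set.
child_pipeline = [dict(type='LoadImageFromFile'), dict(type='ColorJitter')]
child = dict(train_pipeline=child_pipeline, data=dict(samples_per_gpu=128))

merged = merge_config(base, child)
print(merged['data']['samples_per_gpu'])         # 128
print(len(merged['train_pipeline']))             # 2
print(len(merged['data']['train']['pipeline']))  # 1
```

Under this merge model, redefining the top-level `train_pipeline` does not change `data.train.pipeline`, which keeps the base value unless the child also sets something like `data = dict(train=dict(pipeline=train_pipeline))`. Whether the real config loader behaves the same way is worth checking against its documentation, since the configs above define new pipelines but only override `samples_per_gpu` inside `data`.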