Commit

fix review
okotaku committed Apr 19, 2022
1 parent 0b4f3f7 commit 5fa73b8
Showing 3 changed files with 66 additions and 85 deletions.
6 changes: 1 addition & 5 deletions configs/van/README.md
@@ -8,11 +8,7 @@
While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks by a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.

<div align=center>
-<img src="https://user-images.githubusercontent.com/24734142/157409484-f26fcc1f-a856-48c2-a7a7-d157c38877ac.png" width="90%"/>
-</div>
-
-<div align=center>
-<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="90%"/>
+<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="45%"/>
</div>
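The LKA decomposition described in the abstract above is small enough to sketch directly. The following self-contained PyTorch sketch mirrors the kernel shapes used by the `LKA` module added to `mmcls/models/backbones/van.py` below (5×5 depth-wise, 7×7 depth-wise with dilation 3, then 1×1 point-wise); the class and variable names here are illustrative only, not part of the commit:

```python
import torch
import torch.nn as nn


class LKASketch(nn.Module):
    """Illustrative stand-in for the LKA module introduced in this commit."""

    def __init__(self, embed_dims):
        super().__init__()
        # spatial local convolution (depth-wise convolution)
        self.dw_conv = nn.Conv2d(
            embed_dims, embed_dims, kernel_size=5, padding=2,
            groups=embed_dims)
        # spatial long-range convolution (depth-wise dilation convolution);
        # padding = dilation * (kernel_size - 1) // 2 = 3 * 6 // 2 = 9
        self.dw_d_conv = nn.Conv2d(
            embed_dims, embed_dims, kernel_size=7, padding=9,
            groups=embed_dims, dilation=3)
        # channel convolution (1x1 convolution)
        self.conv1 = nn.Conv2d(embed_dims, embed_dims, kernel_size=1)

    def forward(self, x):
        u = x.clone()
        attn = self.dw_conv(x)
        attn = self.dw_d_conv(attn)
        attn = self.conv1(attn)
        return u * attn  # the attention map gates the input element-wise


x = torch.rand(1, 64, 56, 56)
assert LKASketch(64)(x).shape == x.shape  # spatial size is preserved
```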


127 changes: 53 additions & 74 deletions mmcls/models/backbones/van.py
@@ -1,6 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
@@ -20,6 +18,7 @@ class MixFFN(BaseModule):
The differences between MixFFN & FFN:
1. Use 1X1 Conv to replace Linear layer.
    2. Introduce 3X3 Depth-wise Conv to encode positional information.
+
    Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
@@ -53,7 +52,7 @@ def __init__(self,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
-            padding=(3 - 1) // 2,
+            padding=1,
bias=True,
groups=feedforward_channels)
self.act = build_activation_layer(act_cfg)
@@ -63,16 +62,6 @@ def __init__(self,
kernel_size=1)
self.drop = nn.Dropout(ffn_drop)

-    def init_weights(self):
-        super(MixFFN, self).init_weights()
-        for m in self.modules():
-            if isinstance(m, Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
@@ -83,8 +72,13 @@ def forward(self, x):
return x


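For reference alongside the MixFFN changes above, a minimal sketch of the two differences its docstring lists (1×1 convolutions in place of Linear layers, plus a 3×3 depth-wise convolution to encode positional information); the names are illustrative and dropout is omitted:

```python
import torch
import torch.nn as nn


class MixFFNSketch(nn.Module):
    """Illustrative MixFFN: 1x1 convs replace Linear, 3x3 DW conv adds position."""

    def __init__(self, embed_dims, feedforward_channels):
        super().__init__()
        # 1x1 conv instead of nn.Linear, so the 2D layout is kept
        self.fc1 = nn.Conv2d(embed_dims, feedforward_channels, kernel_size=1)
        # 3x3 depth-wise conv encodes positional information
        self.dwconv = nn.Conv2d(
            feedforward_channels, feedforward_channels, kernel_size=3,
            padding=1, groups=feedforward_channels)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(feedforward_channels, embed_dims, kernel_size=1)

    def forward(self, x):  # x: (N, C, H, W); no flatten/transpose needed
        return self.fc2(self.act(self.dwconv(self.fc1(x))))


x = torch.rand(1, 32, 14, 14)
assert MixFFNSketch(32, 128)(x).shape == x.shape
```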
-class AttentionModule(BaseModule):
-    """LKA of VAN.
+class LKA(BaseModule):
+    """Large Kernel Attention (LKA) of VAN.

+    This module has three components: a spatial local convolution
+    (depth-wise convolution), a spatial long-range convolution
+    (depth-wise dilation convolution) and a channel convolution
+    (1×1 convolution).
+
    Args:
        embed_dims (int): Number of input channels.
@@ -93,35 +87,40 @@ class AttentionModule(BaseModule):
"""

def __init__(self, embed_dims, init_cfg=None):
-        super(AttentionModule, self).__init__(init_cfg=init_cfg)
-        self.conv0 = Conv2d(
+        super(LKA, self).__init__(init_cfg=init_cfg)
+
+        # a spatial local convolution (depth-wise convolution)
+        self.DW_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=5,
padding=2,
groups=embed_dims)
-        self.conv_spatial = Conv2d(
+
+        # a spatial long-range convolution (depth-wise dilation convolution)
+        self.DW_D_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=7,
stride=1,
padding=9,
groups=embed_dims,
dilation=3)
+
self.conv1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

def forward(self, x):
u = x.clone()
-        attn = self.conv0(x)
-        attn = self.conv_spatial(attn)
+        attn = self.DW_conv(x)
+        attn = self.DW_D_conv(attn)
attn = self.conv1(attn)

return u * attn


class SpatialAttention(BaseModule):
"""A stage of VAN.
"""Basic attention module in VANBloack.
Args:
embed_dims (int): Number of input channels.
@@ -137,7 +136,7 @@ def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
self.proj_1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
self.activation = build_activation_layer(act_cfg)
-        self.spatial_gating_unit = AttentionModule(embed_dims)
+        self.spatial_gating_unit = LKA(embed_dims)
self.proj_2 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

@@ -199,16 +198,6 @@ def __init__(self,
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True) if layer_scale_init_value > 0 else None

-    def init_weights(self):
-        super(VANBlock, self).init_weights()
-        for m in self.modules():
-            if isinstance(m, Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-
def forward(self, x):
identity = x
x = self.norm1(x)
@@ -228,7 +217,15 @@ def forward(self, x):


class VANPatchEmbed(PatchEmbed):
"""Image to Patch Embedding of VAN."""
"""Image to Patch Embedding of VAN.
The differences between VANPatchEmbed & PatchEmbed:
1. Use BN.
2. Do not use 'flatten' and 'transpose'.
"""

def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)

def forward(self, x):
"""
@@ -297,9 +294,9 @@ class VAN(BaseBackbone):
>>> extra_config = dict(
>>> arch='tiny',
>>> block_cfgs=dict(norm_cfg=dict(type='BN', eps=1e-5)))
-        >>> self = VAN(**extra_config)
+        >>> model = VAN(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
-        >>> output = self.forward(inputs)
+        >>> output = model(inputs)
>>> print(output[0].shape)
(1, 256, 7, 7)
"""
@@ -360,6 +357,7 @@ def __init__(self,
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule

+        cur_block_idx = 0
for i, depth in enumerate(self.depths):
patch_embed = VANPatchEmbed(
in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
@@ -370,39 +368,21 @@ def __init__(self,
padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
norm_cfg=dict(type='BN'))

-            block = ModuleList([
+            blocks = ModuleList([
VANBlock(
embed_dims=self.embed_dims[i],
ffn_ratio=self.ffn_ratios[i],
drop_rate=drop_rate,
-                    drop_path_rate=dpr[j],
+                    drop_path_rate=dpr[cur_block_idx + j],
**block_cfgs) for j in range(depth)
])
+            cur_block_idx += depth
norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
-            dpr = dpr[depth:]

self.add_module(f'patch_embed{i + 1}', patch_embed)
-            self.add_module(f'block{i + 1}', block)
+            self.add_module(f'blocks{i + 1}', blocks)
self.add_module(f'norm{i + 1}', norm)

-    def init_weights(self):
-        super(VAN, self).init_weights()
-        if (isinstance(self.init_cfg, dict)
-                and self.init_cfg['type'] == 'Pretrained'):
-            # Suppress default init if use pretrained model.
-            return
-
-        for m in self.modules():
-            if isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
-            elif isinstance(m, nn.Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-
def train(self, mode=True):
super(VAN, self).train(mode)
self._freeze_stages()
@@ -413,39 +393,38 @@ def train(self, mode=True):
m.eval()

def _freeze_stages(self):
-        if self.frozen_stages >= 0:
-            self.patch_embed1.eval()
-            for param in self.patch_embed1.parameters():
-                param.requires_grad = False
-
-        for i in range(0, self.frozen_stages + 1):
-            if i != 0:
-                m = getattr(self, f'patch_embed{i + 1}')
-                m.eval()
-                for param in m.parameters():
-                    param.requires_grad = False
-
-            m = getattr(self, f'block{i + 1}')
-            m.eval()
-            for param in m.parameters():
-                param.requires_grad = False
-        for i in self.out_indices:
-            if i <= self.frozen_stages:
-                for param in getattr(self, f'norm{i + 1}').parameters():
-                    param.requires_grad = False
+        for i in range(0, self.frozen_stages + 1):
+            # freeze patch embed
+            m = getattr(self, f'patch_embed{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+            # freeze blocks
+            m = getattr(self, f'blocks{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+            # freeze norm
+            m = getattr(self, f'norm{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False

def forward(self, x):
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
-            block = getattr(self, f'block{i + 1}')
+            blocks = getattr(self, f'blocks{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, hw_shape = patch_embed(x)
-            for blk in block:
-                x = blk(x)
+            for block in blocks:
+                x = block(x)
x = x.flatten(2).transpose(1, 2)
x = norm(x)
x = x.reshape(-1, *hw_shape,
-                          blk.out_channels).permute(0, 3, 1, 2).contiguous()
+                          block.out_channels).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outs.append(x)

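One detail worth spelling out from the constructor above: the stochastic depth decay rule assigns each of the `total_depth` blocks a linearly increasing drop-path rate, and the new `cur_block_idx` counter offsets into that flat list per stage instead of re-slicing `dpr`. A small standalone check (the depths here are made up, not read from `arch_settings`):

```python
import torch

depths = (3, 3, 5, 2)  # hypothetical per-stage depths
drop_path_rate = 0.1
total_depth = sum(depths)
# stochastic depth decay rule, as in VAN.__init__
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

cur_block_idx = 0
for i, depth in enumerate(depths):
    stage_rates = [dpr[cur_block_idx + j] for j in range(depth)]
    cur_block_idx += depth
    print(f'stage {i + 1}:', [round(r, 3) for r in stage_rates])

# The rates rise monotonically from 0 to drop_path_rate across all
# 13 blocks, so deeper blocks are dropped more often.
```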
18 changes: 12 additions & 6 deletions tests/test_models/test_backbones/test_van.py
@@ -54,7 +54,7 @@ def test_arch(self):
model = VAN(**cfg)

for i in range(len(depths)):
-            stage = getattr(model, f'block{i + 1}')
+            stage = getattr(model, f'blocks{i + 1}')
self.assertEqual(stage[-1].out_channels, embed_dims[i])
self.assertEqual(len(stage), depths[i])

@@ -129,7 +129,7 @@ def test_structure(self):
cfg['drop_path_rate'] = 0.2
model = VAN(**cfg)
depths = model.arch_settings['depths']
-        stages = [model.block1, model.block2, model.block3, model.block4]
+        stages = [model.blocks1, model.blocks2, model.blocks3, model.blocks4]
blocks = chain(*[stage for stage in stages])
total_depth = sum(depths)
dpr = [
@@ -165,17 +165,23 @@ def test_structure(self):
for param in model.patch_embed1.parameters():
self.assertFalse(param.requires_grad)
for i in range(frozen_stages + 1):
-            stage = getattr(model, f'patch_embed{i+1}')
-            for param in stage.parameters():
+            patch = getattr(model, f'patch_embed{i+1}')
+            for param in patch.parameters():
self.assertFalse(param.requires_grad)
+            blocks = getattr(model, f'blocks{i + 1}')
+            for param in blocks.parameters():
+                self.assertFalse(param.requires_grad)
+            norm = getattr(model, f'norm{i + 1}')
+            for param in norm.parameters():
+                self.assertFalse(param.requires_grad)

        # the stages after the frozen ones should require grad.
for i in range(frozen_stages + 1, 4):
patch = getattr(model, f'patch_embed{i + 1}')
for param in patch.parameters():
self.assertTrue(param.requires_grad)
-            stage = getattr(model, f'block{i+1}')
-            for param in stage.parameters():
+            blocks = getattr(model, f'blocks{i+1}')
+            for param in blocks.parameters():
self.assertTrue(param.requires_grad)
norm = getattr(model, f'norm{i + 1}')
for param in norm.parameters():
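To make the freezing contract these tests pin down concrete, here is a usage sketch (assuming `VAN` is importable from `mmcls.models.backbones` as in this repository; `arch='tiny'` and `frozen_stages=1` are arbitrary choices): with `frozen_stages=1`, the patch embedding, blocks and norm of stages 1 and 2 are put in eval mode and stop receiving gradients, while later stages stay trainable.

```python
from mmcls.models.backbones import VAN

model = VAN(arch='tiny', frozen_stages=1)
model.train()  # train() re-applies _freeze_stages()

# Stages 1 and 2, i.e. range(0, frozen_stages + 1), are frozen...
for i in (1, 2):
    for name in (f'patch_embed{i}', f'blocks{i}', f'norm{i}'):
        assert all(not p.requires_grad
                   for p in getattr(model, name).parameters())

# ...while stages 3 and 4 still require grad.
for i in (3, 4):
    for name in (f'patch_embed{i}', f'blocks{i}', f'norm{i}'):
        assert all(p.requires_grad
                   for p in getattr(model, name).parameters())
```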
