Skip to content

Commit

Permalink
refactor the code
Browse files Browse the repository at this point in the history
  • Loading branch information
tonysy committed Oct 11, 2022
1 parent 344aaf1 commit 5b56157
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 20 deletions.
1 change: 1 addition & 0 deletions configs/_base_/models/tinyvit/tinyvit-11m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
backbone=dict(
type='TinyViT',
arch='tinyvit_11m_224',
resolution=(224, 224),
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
1 change: 1 addition & 0 deletions configs/_base_/models/tinyvit/tinyvit-21m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_224',
resolution=(224, 224),
out_indices=(3, ),
drop_path_rate=0.2,
gap_before_final_norm=True,
Expand Down
1 change: 1 addition & 0 deletions configs/_base_/models/tinyvit/tinyvit-21m-384.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_384',
resolution=(384, 384),
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
1 change: 1 addition & 0 deletions configs/_base_/models/tinyvit/tinyvit-21m-512.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_512',
resolution=(512, 512),
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
1 change: 1 addition & 0 deletions configs/_base_/models/tinyvit/tinyvit-5m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
backbone=dict(
type='TinyViT',
arch='tinyvit_5m_224',
resolution=(224, 224),
out_indices=(3, ),
drop_path_rate=0.0,
gap_before_final_norm=True,
Expand Down
34 changes: 20 additions & 14 deletions mmcls/models/backbones/tinyvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmengine.registry import MODELS
from torch.nn import functional as F

from ..utils import TinyViTAttention
from ..utils import LeAttention
from .base_backbone import BaseBackbone


Expand Down Expand Up @@ -93,6 +93,10 @@ class PatchEmbed(BaseModule):
Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
Conv2d and BatchNorm2d to implement PatchEmbedding, and the output shape is
(N, C, H, W).
Args:
in_channels (int): The number of input channels.
embed_dim (int): The embedding dimension.
Expand Down Expand Up @@ -129,12 +133,15 @@ def forward(self, x):
return self.seq(x)


class TinyViTPatchMerging(nn.Module):
class PatchMerging(nn.Module):
"""Patch Merging for TinyViT.
Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
Different from `mmcls.models.utils.PatchMerging`, this module uses Conv2d
and BatchNorm2d to implement PatchMerging.
Args:
in_channels (int): The number of input channels.
resolution (Tuple[int, int]): The resolution of the input feature.
Expand Down Expand Up @@ -398,7 +405,7 @@ def __init__(self,
head_dim = in_channels // num_heads

window_resolution = (window_size, window_size)
self.attn = TinyViTAttention(
self.attn = LeAttention(
in_channels,
head_dim,
num_heads,
Expand Down Expand Up @@ -590,41 +597,41 @@ class TinyViT(BaseBackbone):
"""
arch_settings = {
'tinyvit_5m_224': {
'resolution': (224, 224),
'channels': [64, 128, 160, 320],
'num_heads': [2, 4, 5, 10],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_11m_224': {
'resolution': (224, 224),
'channels': [64, 128, 256, 448],
'num_heads': [2, 4, 8, 14],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_224': {
'resolution': (224, 224),
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_384': {
'resolution': (384, 384),
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [12, 12, 24, 12],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_512': {
'resolution': (512, 512),
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [16, 16, 32, 16],
'depths': [2, 2, 6, 2],
}
}

def __init__(self,
arch='tinyvit_5m_224',
resolution=(224, 224),
in_channels=3,
depths=[2, 2, 6, 2],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
Expand All @@ -647,15 +654,15 @@ def __init__(self,
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'channels' in arch and 'num_heads' in arch and \
'window_sizes' in arch and 'resolution' in arch, \
'window_sizes' in arch and 'depths' in arch, \
f'Th arch dict must have "channels", "num_heads", ' \
f'"window_sizes" keys, but got {arch.keys()}'

self.channels = arch['channels']
self.num_heads = arch['num_heads']
self.widow_sizes = arch['window_sizes']
self.resolution = arch['resolution']
self.depths = depths
self.resolution = resolution
self.depths = arch['depths']

self.num_stages = len(self.channels)

Expand Down Expand Up @@ -695,8 +702,7 @@ def __init__(self,
curr_resolution = (patches_resolution[0] // (2**i),
patches_resolution[1] // (2**i))
drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
downsample = TinyViTPatchMerging if (
i < self.num_stages - 1) else None
downsample = PatchMerging if (i < self.num_stages - 1) else None
out_channels = self.channels[min(i + 1, self.num_stages - 1)]
if i >= 1:
stage = BasicStage(
Expand Down
8 changes: 4 additions & 4 deletions mmcls/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import (BEiTAttention, ChannelMultiheadAttention,
MultiheadAttention, ShiftWindowMSA, TinyViTAttention,
WindowMSA, WindowMSAV2)
from .attention import (BEiTAttention, ChannelMultiheadAttention, LeAttention,
MultiheadAttention, ShiftWindowMSA, WindowMSA,
WindowMSAV2)
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .data_preprocessor import ClsDataPreprocessor
Expand All @@ -23,5 +23,5 @@
'resize_pos_embed', 'resize_relative_position_bias_table',
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention',
'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention',
'PositionEncodingFourier', 'TinyViTAttention'
'PositionEncodingFourier', 'LeAttention'
]
6 changes: 4 additions & 2 deletions mmcls/models/utils/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,10 @@ def forward(self, x):
return x


class TinyViTAttention(BaseModule):
"""TinyViT Attention.
class LeAttention(BaseModule):
"""LeViT Attention. Multi-head attention with attention bias, which is
proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster
Inference <https://arxiv.org/abs/2104.01136>`_.
Args:
dim (int): Number of input channels.
Expand Down

0 comments on commit 5b56157

Please sign in to comment.