Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mobile net v1 #1545

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/_base_/models/mobilenet_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='MobileNetV1', width_mult=1.0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
133 changes: 133 additions & 0 deletions mmpretrain/models/backbones/mobilenet_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone


@MODELS.register_module()
class MobileNetV1(BaseBackbone):
"""MobileNetV1 backbone for image classification.

Args:
input_channels (int): The input channels of the image tensor.
conv_cfg (dict): Config dict for convolution layer. Default: None.
frozen_stages (int): Stages to be frozen (all param fixed).
-1 means not freezing any parameters.
Default: -1.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not.
Using checkpoint will save some memory while slowing down the
training speed.
Default: False.
init_cfg (list[dict]): Initialization config dict. Default: [
dict(type='Kaiming', layer=['Conv2d']),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
].
"""

def __init__(self,
input_channels,
conv_cfg=None,
frozen_stages=-1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=False,
with_cp=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(MobileNetV1, self).__init__(init_cfg)
self.arch_settings = [[32, 64, 1], [64, 128, 2], [128, 128, 1],
[128, 256, 2], [256, 256, 1], [256, 512, 2],
[512, 512, 1], [512, 512, 1], [512, 512, 1],
[512, 512, 1], [512, 512, 1], [512, 1024, 2],
[1024, 1024, 1]]
if frozen_stages not in range(-1, 8):
raise ValueError('frozen_stages must be in range(-1, 8). '
f'But received {frozen_stages}')
self.in_channels = input_channels
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.layers = []

# Add the first convolution layer to layers
layer = ConvModule(
in_channels=self.in_channels,
out_channels=32,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.layers.append(layer)

for layer_cfg in (self.arch_settings):
in_ch, out_ch, stride = layer_cfg
intermediate_layer = []
depthwise_layer = ConvModule(
in_channels=in_ch,
out_channels=in_ch,
kernel_size=3,
stride=stride,
padding=1,
groups=in_ch,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
pointwise_layer = ConvModule(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
intermediate_layer = nn.Sequential(depthwise_layer,
pointwise_layer)
self.layers.append(intermediate_layer)
self.model = nn.Sequential(*self.layers)

def forward(self, x):
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)

return tuple(outs)

def _freeze_stages(self):
if self.frozen_stages >= 0:
for i in range(0, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False

def train(self, mode=True):
super(MobileNetV1, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
115 changes: 115 additions & 0 deletions tests/test_models/test_backbones/test_mobilenet_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.models.backbones import MobileNetV1

def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False


def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True

def test_mobilenetv1_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = MobileNetV1()
model.init_weights(pretrained=0)

with pytest.raises(ValueError):
# frozen_stages must in range(-1, 8)
MobileNetV1(frozen_stages=8)


# Test MobileNetV2 with first stage frozen
frozen_stages = 1
model = MobileNetV1(frozen_stages=frozen_stages)
model.init_weights()
model.train()

for mod in model.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False

# Test MobileNetV2 with norm_eval=True
model = MobileNetV1(norm_eval=True)
model.init_weights()
model.train()

assert check_norm_state(model.modules(), False)


# Test MobileNetV2 forward with dict(type='ReLU')
model = MobileNetV1(act_cfg=dict(type='ReLU'))
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size((1, 16, 112, 112))
assert feat[1].shape == torch.Size((1, 24, 56, 56))
assert feat[2].shape == torch.Size((1, 32, 28, 28))
assert feat[3].shape == torch.Size((1, 64, 14, 14))
assert feat[4].shape == torch.Size((1, 96, 14, 14))
assert feat[5].shape == torch.Size((1, 160, 7, 7))
assert feat[6].shape == torch.Size((1, 320, 7, 7))

# Test MobileNetV2 with BatchNorm forward
model = MobileNetV1()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size((1, 16, 112, 112))
assert feat[1].shape == torch.Size((1, 24, 56, 56))
assert feat[2].shape == torch.Size((1, 32, 28, 28))
assert feat[3].shape == torch.Size((1, 64, 14, 14))
assert feat[4].shape == torch.Size((1, 96, 14, 14))
assert feat[5].shape == torch.Size((1, 160, 7, 7))
assert feat[6].shape == torch.Size((1, 320, 7, 7))

# Test MobileNetV2 with GroupNorm forward
model = MobileNetV1(
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
for m in model.modules():
if is_norm(m):
assert isinstance(m, GroupNorm)
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size((1, 16, 112, 112))
assert feat[1].shape == torch.Size((1, 24, 56, 56))
assert feat[2].shape == torch.Size((1, 32, 28, 28))
assert feat[3].shape == torch.Size((1, 64, 14, 14))
assert feat[4].shape == torch.Size((1, 96, 14, 14))
assert feat[5].shape == torch.Size((1, 160, 7, 7))
assert feat[6].shape == torch.Size((1, 320, 7, 7))