diff --git a/configs/_base_/models/mobilenet_v1.py b/configs/_base_/models/mobilenet_v1.py new file mode 100644 index 00000000000..9afe3953665 --- /dev/null +++ b/configs/_base_/models/mobilenet_v1.py @@ -0,0 +1,12 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='MobileNetV1', width_mult=1.0), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/models/backbones/mobilenet_v1.py b/mmpretrain/models/backbones/mobilenet_v1.py new file mode 100644 index 00000000000..01afbea982c --- /dev/null +++ b/mmpretrain/models/backbones/mobilenet_v1.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class MobileNetV1(BaseBackbone): + """MobileNetV1 backbone for image classification. + + Args: + input_channels (int): The input channels of the image tensor. + conv_cfg (dict): Config dict for convolution layer. Default: None. + frozen_stages (int): Stages to be frozen (all param fixed). + -1 means not freezing any parameters. + Default: -1. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. + Using checkpoint will save some memory while slowing down the + training speed. + Default: False. + init_cfg (list[dict]): Initialization config dict. Default: [ + dict(type='Kaiming', layer=['Conv2d']), + dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ]. + """ + + def __init__(self, + input_channels, + conv_cfg=None, + frozen_stages=-1, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + norm_eval=False, + with_cp=False, + init_cfg=[ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ]): + super(MobileNetV1, self).__init__(init_cfg) + self.arch_settings = [[32, 64, 1], [64, 128, 2], [128, 128, 1], + [128, 256, 2], [256, 256, 1], [256, 512, 2], + [512, 512, 1], [512, 512, 1], [512, 512, 1], + [512, 512, 1], [512, 512, 1], [512, 1024, 2], + [1024, 1024, 1]] + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). ' + f'But received {frozen_stages}') + self.in_channels = input_channels + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = [] + + # Add the first convolution layer to layers + layer = ConvModule( + in_channels=self.in_channels, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.layers.append(layer) + + for layer_cfg in (self.arch_settings): + in_ch, out_ch, stride = layer_cfg + intermediate_layer = [] + depthwise_layer = ConvModule( + in_channels=in_ch, + out_channels=in_ch, + kernel_size=3, + stride=stride, + padding=1, + groups=in_ch, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + pointwise_layer = ConvModule( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + intermediate_layer = nn.Sequential(depthwise_layer, + pointwise_layer) + self.layers.append(intermediate_layer) + self.model = nn.Sequential(*self.layers) + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for i in range(0, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV1, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/tests/test_models/test_backbones/test_mobilenet_v1.py b/tests/test_models/test_backbones/test_mobilenet_v1.py new file mode 100644 index 00000000000..d4ecfa97b15 --- /dev/null +++ b/tests/test_models/test_backbones/test_mobilenet_v1.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch +from torch.nn.modules import GroupNorm +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones import MobileNetV1 + +def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + +def test_mobilenetv1_backbone(): + with pytest.raises(TypeError): + # pretrained must be a string path + model = MobileNetV1() + model.init_weights(pretrained=0) + + with pytest.raises(ValueError): + # frozen_stages must in range(-1, 8) + MobileNetV1(frozen_stages=8) + + + # Test MobileNetV2 with first stage frozen + frozen_stages = 1 + model = MobileNetV1(frozen_stages=frozen_stages) + model.init_weights() + model.train() + + for mod in model.modules(): + for param in mod.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, f'layer{i}') + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False + + # Test MobileNetV2 with norm_eval=True + model = MobileNetV1(norm_eval=True) + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), False) + + + # Test MobileNetV2 forward with dict(type='ReLU') + model = MobileNetV1(act_cfg=dict(type='ReLU')) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + # Test MobileNetV2 with BatchNorm forward + model = MobileNetV1() + for m in model.modules(): + if is_norm(m): + assert isinstance(m, _BatchNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + # Test MobileNetV2 with GroupNorm forward + model = MobileNetV1( + norm_cfg=dict(type='GN', num_groups=2, requires_grad=True)) + for m in model.modules(): + if is_norm(m): + assert isinstance(m, GroupNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + \ No newline at end of file