diff --git a/mmocr/models/common/backbones/__init__.py b/mmocr/models/common/backbones/__init__.py index 3c384ba30..7e0c98ef7 100644 --- a/mmocr/models/common/backbones/__init__.py +++ b/mmocr/models/common/backbones/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .clip_resnet import CLIPResNet from .unet import UNet -__all__ = ['UNet'] +__all__ = ['UNet', 'CLIPResNet'] diff --git a/mmocr/models/common/backbones/clip_resnet.py b/mmocr/models/common/backbones/clip_resnet.py new file mode 100644 index 000000000..4de20986b --- /dev/null +++ b/mmocr/models/common/backbones/clip_resnet.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch.nn as nn +from mmdet.models.backbones import ResNet +from mmdet.models.backbones.resnet import Bottleneck + +from mmocr.registry import MODELS + + +class CLIPBottleneck(Bottleneck): + """Bottleneck for CLIPResNet. + + It is a Bottleneck variant used in the ResNet variant of CLIP. After the + second convolution layer, there is an additional average pooling layer with + kernel_size 2 and stride 2, which is added as a plugin when the + input stride > 1. The stride of each convolution layer is always set to 1. + + Args: + **kwargs: Keyword arguments for + :class:``mmdet.models.backbones.resnet.Bottleneck``. + """ + + def __init__(self, **kwargs): + stride = kwargs.get('stride', 1) + kwargs['stride'] = 1 + plugins = kwargs.get('plugins', None) + if stride > 1: + if plugins is None: + plugins = [] + + plugins.insert( + 0, + dict( + cfg=dict(type='mmocr.AvgPool2d', kernel_size=2), + position='after_conv2')) + kwargs['plugins'] = plugins + super().__init__(**kwargs) + + +@MODELS.register_module() +class CLIPResNet(ResNet): + """Implement the ResNet variant used in `oCLIP. + + `_. + + It is also the official structure in + `CLIP `_. + + Compared with ResNetV1d structure, CLIPResNet replaces the + max pooling layer with an average pooling layer at the end + of the input stem. + + In the Bottleneck of CLIPResNet, after the second convolution + layer, there is an additional average pooling layer with + kernel_size 2 and stride 2, which is added as a plugin + when the input stride > 1. + The stride of each convolution layer is always set to 1. + + Args: + depth (int): Depth of resnet, options are [50]. Defaults to 50. + strides (sequence(int)): Strides of the first block of each stage. + Defaults to (1, 2, 2, 2). + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to True. + avg_down (bool): Use AvgPool instead of stride conv at + the downsampling stage in the bottleneck. Defaults to True. + **kwargs: Keyword arguments for + :class:``mmdet.models.backbones.resnet.ResNet``. + """ + arch_settings = { + 50: (CLIPBottleneck, (3, 4, 6, 3)), + } + + def __init__(self, + depth=50, + strides=(1, 2, 2, 2), + deep_stem=True, + avg_down=True, + **kwargs): + super().__init__( + depth=depth, + strides=strides, + deep_stem=deep_stem, + avg_down=avg_down, + **kwargs) + + def _make_stem_layer(self, in_channels: int, stem_channels: int): + """Build stem layer for CLIPResNet used in `CLIP + https://github.com/openai/CLIP>`_. + + It uses an average pooling layer rather than a max pooling + layer at the end of the input stem. + + Args: + in_channels (int): Number of input channels. + stem_channels (int): Number of output channels. + """ + super()._make_stem_layer(in_channels, stem_channels) + if self.deep_stem: + self.maxpool = nn.AvgPool2d(kernel_size=2) diff --git a/mmocr/models/common/plugins/__init__.py b/mmocr/models/common/plugins/__init__.py new file mode 100644 index 000000000..1ad4c93c0 --- /dev/null +++ b/mmocr/models/common/plugins/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .common import AvgPool2d + +__all__ = ['AvgPool2d'] diff --git a/mmocr/models/common/plugins/common.py b/mmocr/models/common/plugins/common.py new file mode 100644 index 000000000..722b53f56 --- /dev/null +++ b/mmocr/models/common/plugins/common.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from mmocr.registry import MODELS + + +@MODELS.register_module() +class AvgPool2d(nn.Module): + """Applies a 2D average pooling over an input signal composed of several + input planes. + + It can also be used as a network plugin. + + Args: + kernel_size (int or tuple(int)): the size of the window. + stride (int or tuple(int), optional): the stride of the window. + Defaults to None. + padding (int or tuple(int)): implicit zero padding. Defaults to 0. + """ + + def __init__(self, + kernel_size: Union[int, Tuple[int]], + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Union[int, Tuple[int]] = 0, + **kwargs) -> None: + super().__init__() + self.model = nn.AvgPool2d(kernel_size, stride, padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (Tensor): Input feature map. + + Returns: + Tensor: Output tensor after Avgpooling layer. + """ + return self.model(x) diff --git a/tests/test_models/test_common/test_backbones/test_clip_resnet.py b/tests/test_models/test_common/test_backbones/test_clip_resnet.py new file mode 100644 index 000000000..fd71395f8 --- /dev/null +++ b/tests/test_models/test_common/test_backbones/test_clip_resnet.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmocr.models.common.backbones import CLIPResNet +from mmocr.models.common.backbones.clip_resnet import CLIPBottleneck + + +class TestCLIPResNet(TestCase): + + def test_forward(self): + model = CLIPResNet() + model.eval() + + imgs = torch.randn(1, 3, 32, 32) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 256, 8, 8]) + assert feat[1].shape == torch.Size([1, 512, 4, 4]) + assert feat[2].shape == torch.Size([1, 1024, 2, 2]) + assert feat[3].shape == torch.Size([1, 2048, 1, 1]) + + +class TestCLIPBottleneck(TestCase): + + def test_forward(self): + stride = 2 + inplanes = 256 + planes = 128 + conv_cfg = None + norm_cfg = {'type': 'BN', 'requires_grad': True} + + downsample = [] + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * CLIPBottleneck.expansion, + kernel_size=1, + stride=1, + bias=False), + build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + model = CLIPBottleneck( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + model.eval() + + input_feat = torch.randn(1, 256, 8, 8) + output_feat = model(input_feat) + assert output_feat.shape == torch.Size([1, 512, 4, 4]) diff --git a/tests/test_models/test_common/test_plugins/test_avgpool.py b/tests/test_models/test_common/test_plugins/test_avgpool.py new file mode 100644 index 000000000..766ddf5d6 --- /dev/null +++ b/tests/test_models/test_common/test_plugins/test_avgpool.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.common.plugins import AvgPool2d + + +class TestAvgPool2d(TestCase): + + def setUp(self) -> None: + self.img = torch.rand(1, 3, 32, 100) + + def test_avgpool2d(self): + avgpool2d = AvgPool2d(kernel_size=2, stride=2) + self.assertEqual(avgpool2d(self.img).shape, torch.Size([1, 3, 16, 50]))